	Merge pull request #20 from argilla-io/feat/improve-support-local-deployment
Files changed:

- README.md +11 -7
- examples/{argilla_deployment.py → argilla-deployment.py} +5 -2
- examples/fine-tune-modernbert-classifier.ipynb +1 -1
- examples/{openai_local.py → hf-dedicated-or-tgi-deployment.py} +7 -4
- examples/{enforce_mapgie_template copy.py → hf-serverless-deployment.py} +3 -2
- examples/ollama-deployment.py +22 -0
- examples/{ollama_local.py → openai-deployment.py} +6 -3
- examples/vllm-deployment.py +21 -0
- pdm.lock +51 -16
- pyproject.toml +1 -1
- src/synthetic_dataset_generator/_distiset.py +10 -0
- src/synthetic_dataset_generator/apps/base.py +7 -2
- src/synthetic_dataset_generator/apps/chat.py +8 -11
- src/synthetic_dataset_generator/apps/eval.py +3 -1
- src/synthetic_dataset_generator/apps/textcat.py +41 -20
- src/synthetic_dataset_generator/constants.py +53 -25
- src/synthetic_dataset_generator/pipelines/base.py +132 -1
- src/synthetic_dataset_generator/pipelines/chat.py +36 -72
- src/synthetic_dataset_generator/pipelines/textcat.py +12 -52
- src/synthetic_dataset_generator/utils.py +5 -0
    	
README.md CHANGED
@@ -28,7 +28,7 @@ hf_oauth_scopes:
 
 ## Introduction
 
-Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also …
+Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also watch the [video](https://www.youtube.com/watch?v=nXjVtnGeEss) to see it in action.
 
 Supported Tasks:
 
@@ -76,21 +76,25 @@ launch()
 
 - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. You can find some configuration examples in the [examples](examples/) folder.
 
-…
+You can set the following environment variables to customize the generation process.
 
 - `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
 - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
 - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
 
-Optionally, you can use different …
+Optionally, you can use different API providers and models.
 
-- `…
-- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `openai/gpt-4o`, `ollama/llama3.1`.
+- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
 - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
+- `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
+- `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
+- `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
+- `VLLM_BASE_URL`: The base URL for any VLLM compatible API, e.g. `http://localhost:8000/`.
 
-SFT and Chat Data generation is …
+SFT and Chat Data generation is not supported with OpenAI Endpoints. Additionally, you need to configure it per model family based on their prompt templates using the right `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE` environment variables.
 
-- `…
+- `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
+- `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. `llama3` and `qwen2` are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"`, respectively. For other models, you can pass a custom pre-query template string.
 
 Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
 
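For model families other than the built-in `llama3` and `qwen2`, the last variable accepts a raw template string. A minimal configuration sketch, where the model id and the ChatML-style template are illustrative assumptions rather than values from this PR:

import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..."  # Hub token, as in the examples below
os.environ["MODEL"] = "my-org/my-chat-model"  # hypothetical model id
os.environ["TOKENIZER_ID"] = "my-org/my-chat-model"  # hypothetical tokenizer id
# Custom pre-query template: the text that precedes the user turn in the model's
# chat format (ChatML-style here, mirroring the qwen2 default quoted above).
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "<|im_start|>user\n"

launch()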
    	
examples/{argilla_deployment.py → argilla-deployment.py} RENAMED
@@ -9,7 +9,10 @@ import os
 from synthetic_dataset_generator import launch
 
 # Follow https://docs.argilla.io/latest/getting_started/quickstart/ to get your Argilla API key and URL
-os.environ["…
-os.environ["…
+os.environ["HF_TOKEN"] = "hf_..."
+os.environ["ARGILLA_API_URL"] = (
+    "https://[your-owner-name]-[your_space_name].hf.space"  # argilla base url
+)
+os.environ["ARGILLA_API_KEY"] = "my_api_key"  # argilla api key
 
 launch()
        examples/fine-tune-modernbert-classifier.ipynb
    CHANGED
    
    | @@ -530,7 +530,7 @@ | |
| 530 | 
             
               "name": "python",
         | 
| 531 | 
             
               "nbconvert_exporter": "python",
         | 
| 532 | 
             
               "pygments_lexer": "ipython3",
         | 
| 533 | 
            -
               "version": "3.11. | 
| 534 | 
             
              }
         | 
| 535 | 
             
             },
         | 
| 536 | 
             
             "nbformat": 4,
         | 
|  | |
| 530 | 
             
               "name": "python",
         | 
| 531 | 
             
               "nbconvert_exporter": "python",
         | 
| 532 | 
             
               "pygments_lexer": "ipython3",
         | 
| 533 | 
            +
               "version": "3.11.11"
         | 
| 534 | 
             
              }
         | 
| 535 | 
             
             },
         | 
| 536 | 
             
             "nbformat": 4,
         | 
    	
examples/{openai_local.py → hf-dedicated-or-tgi-deployment.py} RENAMED
@@ -8,9 +8,12 @@ import os
 
 from synthetic_dataset_generator import launch
 
-…
-os.environ["…
-os.environ["…
-os.environ["…
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/"  # dedicated endpoint/TGI
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # magpie template
+os.environ["TOKENIZER_ID"] = (
+    "meta-llama/Llama-3.1-8B-Instruct"  # tokenizer for model hosted on endpoint
+)
+os.environ["MODEL"] = None  # model is linked to endpoint
 
 launch()
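One caveat with the rewritten example: `os.environ` values must be strings, so `os.environ["MODEL"] = None` raises `TypeError` if executed as-is (the same holds for the `MAGPIE_PRE_QUERY_TEMPLATE = None` assignment in `openai-deployment.py` below). An assumed equivalent of "the model is linked to the endpoint" is simply to make sure the variable is unset:

import os

os.environ.pop("MODEL", None)  # leave MODEL undefined; the endpoint determines the model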
    	
examples/{enforce_mapgie_template copy.py → hf-serverless-deployment.py} RENAMED
@@ -8,7 +8,8 @@ import os
 
 from synthetic_dataset_generator import launch
 
-os.environ["…
-os.environ["MODEL"] = "…
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # use instruct model
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # use the template for the model
 
 launch()
    	
examples/ollama-deployment.py ADDED
@@ -0,0 +1,22 @@
+# /// script
+# requires-python = ">=3.11,<3.12"
+# dependencies = [
+#     "synthetic-dataset-generator",
+# ]
+# ///
+# ollama serve
+# ollama run qwen2.5:32b-instruct-q5_K_S
+import os
+
+from synthetic_dataset_generator import launch
+
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"  # ollama base url
+os.environ["MODEL"] = "qwen2.5:32b-instruct-q5_K_S"  # model id
+os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-32B-Instruct"  # tokenizer id
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
+os.environ["MAX_NUM_ROWS"] = "10000"
+os.environ["DEFAULT_BATCH_SIZE"] = "2"
+os.environ["MAX_NUM_TOKENS"] = "1024"
+
+launch()
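The `# /// script` block at the top is PEP 723 inline metadata, so this example should be runnable directly with any PEP 723-aware runner, e.g. `uv run examples/ollama-deployment.py` (an assumption about your tooling, not something this PR sets up); the runner resolves the `synthetic-dataset-generator` dependency on the fly once the two `ollama` commands from the comments are running.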
    	
examples/{ollama_local.py → openai-deployment.py} RENAMED
@@ -4,12 +4,15 @@
 #     "synthetic-dataset-generator",
 # ]
 # ///
+
 import os
 
 from synthetic_dataset_generator import launch
 
-…
-os.environ["…
-os.environ["…
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"  # openai base url
+os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY")  # openai api key
+os.environ["MODEL"] = "gpt-4o"  # model id
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = None  # chat data not supported with OpenAI
 
 launch()
    	
examples/vllm-deployment.py ADDED
@@ -0,0 +1,21 @@
+# /// script
+# requires-python = ">=3.11,<3.12"
+# dependencies = [
+#     "synthetic-dataset-generator",
+# ]
+# ///
+# vllm serve Qwen/Qwen2.5-1.5B-Instruct
+import os
+
+from synthetic_dataset_generator import launch
+
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["VLLM_BASE_URL"] = "http://127.0.0.1:8000/"  # vllm base url
+os.environ["MODEL"] = "Qwen/Qwen2.5-1.5B-Instruct"  # model id
+os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-1.5B-Instruct"  # tokenizer id
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
+os.environ["MAX_NUM_ROWS"] = "10000"
+os.environ["DEFAULT_BATCH_SIZE"] = "2"
+os.environ["MAX_NUM_TOKENS"] = "1024"
+
+launch()
    	
pdm.lock CHANGED
@@ -5,7 +5,7 @@
 groups = ["default"]
 strategy = ["inherit_metadata"]
 lock_version = "4.5.0"
-content_hash = "sha256:…
+content_hash = "sha256:e95140895657d62ad438ff1815ddf1798abbb342ddd2649ae462620b8b3f5350"
 
 [[metadata.targets]]
 requires_python = ">=3.10,<3.13"
@@ -491,8 +491,11 @@ files = [
 
 [[package]]
 name = "distilabel"
-version = "1.4.1"
+version = "1.5.0"
 requires_python = ">=3.9"
+git = "https://github.com/argilla-io/distilabel.git"
+ref = "feat/add-magpie-support-llama-cpp-ollama"
+revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6"
 summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
 groups = ["default"]
 dependencies = [
@@ -512,30 +515,30 @@ dependencies = [
     "typer>=0.9.0",
     "universal-pathlib>=0.2.2",
 ]
-files = [
-    {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"},
-    {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"},
-]
 
 [[package]]
 name = "distilabel"
-version = "1.4.1"
-extras = ["argilla", "hf-inference-endpoints", "instructor", "outlines"]
+version = "1.5.0"
+extras = ["argilla", "hf-inference-endpoints", "hf-transformers", "instructor", "llama-cpp", "ollama", "openai", "outlines"]
 requires_python = ">=3.9"
+git = "https://github.com/argilla-io/distilabel.git"
+ref = "feat/add-magpie-support-llama-cpp-ollama"
+revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6"
 summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
 groups = ["default"]
 dependencies = [
     "argilla>=2.0.0",
-    "distilabel…
+    "distilabel @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama",
     "huggingface-hub>=0.22.0",
     "instructor>=1.2.3",
     "ipython",
+    "llama-cpp-python>=0.2.0",
     "numba>=0.54.0",
+    "ollama>=0.1.7",
+    "openai>=1.0.0",
     "outlines>=0.0.40",
-]
-files = [
-    {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"},
-    {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"},
+    "torch>=2.0.0",
+    "transformers>=4.34.1",
 ]
 
 [[package]]
@@ -824,7 +827,7 @@ files = [
 
 [[package]]
 name = "httpx"
-version = "0.…
+version = "0.27.2"
 requires_python = ">=3.8"
 summary = "The next generation HTTP client."
 groups = ["default"]
@@ -833,10 +836,11 @@ dependencies = [
     "certifi",
     "httpcore==1.*",
     "idna",
+    "sniffio",
 ]
 files = [
-    {file = "httpx-0.…
-    {file = "httpx-0.…
+    {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
+    {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
 ]
 
 [[package]]
@@ -1068,6 +1072,22 @@ files = [
     {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"},
 ]
 
+[[package]]
+name = "llama-cpp-python"
+version = "0.3.5"
+requires_python = ">=3.8"
+summary = "Python bindings for the llama.cpp library"
+groups = ["default"]
+dependencies = [
+    "diskcache>=5.6.1",
+    "jinja2>=2.11.3",
+    "numpy>=1.20.0",
+    "typing-extensions>=4.5.0",
+]
+files = [
+    {file = "llama_cpp_python-0.3.5.tar.gz", hash = "sha256:f5ce47499d53d3973e28ca5bdaf2dfe820163fa3fb67e3050f98e2e9b58d2cf6"},
+]
+
 [[package]]
 name = "llvmlite"
 version = "0.43.0"
@@ -1538,6 +1558,21 @@ files = [
     {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
 ]
 
+[[package]]
+name = "ollama"
+version = "0.4.4"
+requires_python = "<4.0,>=3.8"
+summary = "The official Python client for Ollama."
+groups = ["default"]
+dependencies = [
+    "httpx<0.28.0,>=0.27.0",
+    "pydantic<3.0.0,>=2.9.0",
+]
+files = [
+    {file = "ollama-0.4.4-py3-none-any.whl", hash = "sha256:0f466e845e2205a1cbf5a2fef4640027b90beaa3b06c574426d8b6b17fd6e139"},
+    {file = "ollama-0.4.4.tar.gz", hash = "sha256:e1db064273c739babc2dde9ea84029c4a43415354741b6c50939ddd3dd0f7ffb"},
+]
+
 [[package]]
 name = "openai"
 version = "1.57.4"
    	
pyproject.toml CHANGED
    | @@ -18,7 +18,7 @@ readme = "README.md" | |
| 18 | 
             
            license = {text = "Apache 2"}
         | 
| 19 |  | 
| 20 | 
             
            dependencies = [
         | 
| 21 | 
            -
                "distilabel[hf-inference-endpoints, | 
| 22 | 
             
                "gradio[oauth]>=5.4.0,<6.0.0",
         | 
| 23 | 
             
                "transformers>=4.44.2,<5.0.0",
         | 
| 24 | 
             
                "sentence-transformers>=3.2.0,<4.0.0",
         | 
|  | |
| 18 | 
             
            license = {text = "Apache 2"}
         | 
| 19 |  | 
| 20 | 
             
            dependencies = [
         | 
| 21 | 
            +
                "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm] @ git+https://github.com/argilla-io/distilabel.git@develop",
         | 
| 22 | 
             
                "gradio[oauth]>=5.4.0,<6.0.0",
         | 
| 23 | 
             
                "transformers>=4.44.2,<5.0.0",
         | 
| 24 | 
             
                "sentence-transformers>=3.2.0,<4.0.0",
         | 
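Note that pyproject.toml now tracks distilabel's `develop` branch while pdm.lock above pins the `feat/add-magpie-support-llama-cpp-ollama` branch at revision `4e291e7b`; installs that go through the lock file therefore resolve to the pinned revision rather than the tip of `develop`.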
    	
src/synthetic_dataset_generator/_distiset.py CHANGED
@@ -81,6 +81,15 @@ class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
                 dataset[0] if not isinstance(dataset, dict) else dataset["train"][0]
             )
 
+        keys = list(sample_records.keys())
+        if len(keys) != 2 or not (
+            ("label" in keys and "text" in keys)
+            or ("labels" in keys and "text" in keys)
+        ):
+            task_categories = ["text-classification"]
+        elif "prompt" in keys or "messages" in keys:
+            task_categories = ["text-generation", "text2text-generation"]
+
         readme_metadata = {}
         if repo_id and token:
             readme_metadata = self._extract_readme_metadata(repo_id, token)
@@ -90,6 +99,7 @@ class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
             "size_categories": size_categories_parser(
                 max(len(dataset) for dataset in self.values())
             ),
+            "task_categories": task_categories,
             "tags": [
                 "synthetic",
                 "distilabel",
        src/synthetic_dataset_generator/apps/base.py
    CHANGED
    
    | @@ -77,10 +77,15 @@ def validate_push_to_hub(org_name, repo_name): | |
| 77 | 
             
                return repo_id
         | 
| 78 |  | 
| 79 |  | 
| 80 | 
            -
            def combine_datasets( | 
|  | |
|  | |
| 81 | 
             
                try:
         | 
| 82 | 
             
                    new_dataset = load_dataset(
         | 
| 83 | 
            -
                        repo_id, | 
|  | |
|  | |
|  | |
| 84 | 
             
                    )
         | 
| 85 | 
             
                    return concatenate_datasets([dataset, new_dataset])
         | 
| 86 | 
             
                except Exception:
         | 
|  | |
| 77 | 
             
                return repo_id
         | 
| 78 |  | 
| 79 |  | 
| 80 | 
            +
            def combine_datasets(
         | 
| 81 | 
            +
                repo_id: str, dataset: Dataset, oauth_token: Union[OAuthToken, None]
         | 
| 82 | 
            +
            ) -> Dataset:
         | 
| 83 | 
             
                try:
         | 
| 84 | 
             
                    new_dataset = load_dataset(
         | 
| 85 | 
            +
                        repo_id,
         | 
| 86 | 
            +
                        split="train",
         | 
| 87 | 
            +
                        download_mode="force_redownload",
         | 
| 88 | 
            +
                        token=oauth_token.token,
         | 
| 89 | 
             
                    )
         | 
| 90 | 
             
                    return concatenate_datasets([dataset, new_dataset])
         | 
| 91 | 
             
                except Exception:
         | 
    	
        src/synthetic_dataset_generator/apps/chat.py
    CHANGED
    
    | @@ -25,12 +25,12 @@ from synthetic_dataset_generator.constants import ( | |
| 25 | 
             
                MODEL,
         | 
| 26 | 
             
                SFT_AVAILABLE,
         | 
| 27 | 
             
            )
         | 
|  | |
| 28 | 
             
            from synthetic_dataset_generator.pipelines.chat import (
         | 
| 29 | 
             
                DEFAULT_DATASET_DESCRIPTIONS,
         | 
| 30 | 
             
                generate_pipeline_code,
         | 
| 31 | 
             
                get_magpie_generator,
         | 
| 32 | 
             
                get_prompt_generator,
         | 
| 33 | 
            -
                get_prompt_rewriter,
         | 
| 34 | 
             
                get_response_generator,
         | 
| 35 | 
             
            )
         | 
| 36 | 
             
            from synthetic_dataset_generator.pipelines.embeddings import (
         | 
| @@ -40,6 +40,7 @@ from synthetic_dataset_generator.pipelines.embeddings import ( | |
| 40 | 
             
            from synthetic_dataset_generator.utils import (
         | 
| 41 | 
             
                get_argilla_client,
         | 
| 42 | 
             
                get_org_dropdown,
         | 
|  | |
| 43 | 
             
                swap_visibility,
         | 
| 44 | 
             
            )
         | 
| 45 |  | 
| @@ -106,7 +107,6 @@ def generate_dataset( | |
| 106 | 
             
            ) -> pd.DataFrame:
         | 
| 107 | 
             
                num_rows = test_max_num_rows(num_rows)
         | 
| 108 | 
             
                progress(0.0, desc="(1/2) Generating instructions")
         | 
| 109 | 
            -
                prompt_rewriter = get_prompt_rewriter()
         | 
| 110 | 
             
                magpie_generator = get_magpie_generator(
         | 
| 111 | 
             
                    system_prompt, num_turns, temperature, is_sample
         | 
| 112 | 
             
                )
         | 
| @@ -117,14 +117,7 @@ def generate_dataset( | |
| 117 | 
             
                batch_size = DEFAULT_BATCH_SIZE
         | 
| 118 |  | 
| 119 | 
             
                # create prompt rewrites
         | 
| 120 | 
            -
                 | 
| 121 | 
            -
                    {
         | 
| 122 | 
            -
                        "instruction": f"Rewrite this prompt keeping the same structure but highlighting different aspects of the original without adding anything new. Original prompt: {system_prompt} Rewritten prompt: "
         | 
| 123 | 
            -
                    }
         | 
| 124 | 
            -
                    for i in range(int(num_rows / 50))
         | 
| 125 | 
            -
                ]
         | 
| 126 | 
            -
                batch = list(prompt_rewriter.process(inputs=inputs))
         | 
| 127 | 
            -
                prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt]
         | 
| 128 |  | 
| 129 | 
             
                # create instructions
         | 
| 130 | 
             
                n_processed = 0
         | 
| @@ -142,6 +135,7 @@ def generate_dataset( | |
| 142 | 
             
                    batch = list(magpie_generator.process(inputs=inputs))
         | 
| 143 | 
             
                    magpie_results.extend(batch[0])
         | 
| 144 | 
             
                    n_processed += batch_size
         | 
|  | |
| 145 | 
             
                progress(0.5, desc="(1/2) Generating instructions")
         | 
| 146 |  | 
| 147 | 
             
                # generate responses
         | 
| @@ -158,6 +152,7 @@ def generate_dataset( | |
| 158 | 
             
                        responses = list(response_generator.process(inputs=batch))
         | 
| 159 | 
             
                        response_results.extend(responses[0])
         | 
| 160 | 
             
                        n_processed += batch_size
         | 
|  | |
| 161 | 
             
                    for result in response_results:
         | 
| 162 | 
             
                        result["prompt"] = result["instruction"]
         | 
| 163 | 
             
                        result["completion"] = result["generation"]
         | 
| @@ -178,6 +173,7 @@ def generate_dataset( | |
| 178 | 
             
                        responses = list(response_generator.process(inputs=batch))
         | 
| 179 | 
             
                        response_results.extend(responses[0])
         | 
| 180 | 
             
                        n_processed += batch_size
         | 
|  | |
| 181 | 
             
                    for result in response_results:
         | 
| 182 | 
             
                        result["messages"].append(
         | 
| 183 | 
             
                            {"role": "assistant", "content": result["generation"]}
         | 
| @@ -236,7 +232,7 @@ def push_dataset_to_hub( | |
| 236 | 
             
                dataframe = convert_dataframe_messages(dataframe)
         | 
| 237 | 
             
                progress(0.7, desc="Creating dataset")
         | 
| 238 | 
             
                dataset = Dataset.from_pandas(dataframe)
         | 
| 239 | 
            -
                dataset = combine_datasets(repo_id, dataset)
         | 
| 240 | 
             
                progress(0.9, desc="Pushing dataset")
         | 
| 241 | 
             
                distiset = Distiset({"default": dataset})
         | 
| 242 | 
             
                distiset.push_to_hub(
         | 
| @@ -600,4 +596,5 @@ with gr.Blocks() as app: | |
| 600 | 
             
                            outputs=[dataset_description, system_prompt, num_turns, dataframe],
         | 
| 601 | 
             
                        )
         | 
| 602 | 
             
                        app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
|  | |
| 603 | 
             
                    app.load(fn=swap_visibility, outputs=main_ui)
         | 
|  | |
| 25 | 
             
                MODEL,
         | 
| 26 | 
             
                SFT_AVAILABLE,
         | 
| 27 | 
             
            )
         | 
| 28 | 
            +
            from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
         | 
| 29 | 
             
            from synthetic_dataset_generator.pipelines.chat import (
         | 
| 30 | 
             
                DEFAULT_DATASET_DESCRIPTIONS,
         | 
| 31 | 
             
                generate_pipeline_code,
         | 
| 32 | 
             
                get_magpie_generator,
         | 
| 33 | 
             
                get_prompt_generator,
         | 
|  | |
| 34 | 
             
                get_response_generator,
         | 
| 35 | 
             
            )
         | 
| 36 | 
             
            from synthetic_dataset_generator.pipelines.embeddings import (
         | 
|  | |
| 40 | 
             
            from synthetic_dataset_generator.utils import (
         | 
| 41 | 
             
                get_argilla_client,
         | 
| 42 | 
             
                get_org_dropdown,
         | 
| 43 | 
            +
                get_random_repo_name,
         | 
| 44 | 
             
                swap_visibility,
         | 
| 45 | 
             
            )
         | 
| 46 |  | 
|  | |
| 107 | 
             
            ) -> pd.DataFrame:
         | 
| 108 | 
             
                num_rows = test_max_num_rows(num_rows)
         | 
| 109 | 
             
                progress(0.0, desc="(1/2) Generating instructions")
         | 
|  | |
| 110 | 
             
                magpie_generator = get_magpie_generator(
         | 
| 111 | 
             
                    system_prompt, num_turns, temperature, is_sample
         | 
| 112 | 
             
                )
         | 
|  | |
| 117 | 
             
                batch_size = DEFAULT_BATCH_SIZE
         | 
| 118 |  | 
| 119 | 
             
                # create prompt rewrites
         | 
| 120 | 
            +
                prompt_rewrites = get_rewriten_prompts(system_prompt, num_rows)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 121 |  | 
| 122 | 
             
                # create instructions
         | 
| 123 | 
             
                n_processed = 0
         | 
|  | |
| 135 | 
             
                    batch = list(magpie_generator.process(inputs=inputs))
         | 
| 136 | 
             
                    magpie_results.extend(batch[0])
         | 
| 137 | 
             
                    n_processed += batch_size
         | 
| 138 | 
            +
                    random.seed(a=random.randint(0, 2**32 - 1))
         | 
| 139 | 
             
                progress(0.5, desc="(1/2) Generating instructions")
         | 
| 140 |  | 
| 141 | 
             
                # generate responses
         | 
|  | |
| 152 | 
             
                        responses = list(response_generator.process(inputs=batch))
         | 
| 153 | 
             
                        response_results.extend(responses[0])
         | 
| 154 | 
             
                        n_processed += batch_size
         | 
| 155 | 
            +
                        random.seed(a=random.randint(0, 2**32 - 1))
         | 
| 156 | 
             
                    for result in response_results:
         | 
| 157 | 
             
                        result["prompt"] = result["instruction"]
         | 
| 158 | 
             
                        result["completion"] = result["generation"]
         | 
|  | |
| 173 | 
             
                        responses = list(response_generator.process(inputs=batch))
         | 
| 174 | 
             
                        response_results.extend(responses[0])
         | 
| 175 | 
             
                        n_processed += batch_size
         | 
| 176 | 
            +
                        random.seed(a=random.randint(0, 2**32 - 1))
         | 
| 177 | 
             
                    for result in response_results:
         | 
| 178 | 
             
                        result["messages"].append(
         | 
| 179 | 
             
                            {"role": "assistant", "content": result["generation"]}
         | 
|  | |
| 232 | 
             
                dataframe = convert_dataframe_messages(dataframe)
         | 
| 233 | 
             
                progress(0.7, desc="Creating dataset")
         | 
| 234 | 
             
                dataset = Dataset.from_pandas(dataframe)
         | 
| 235 | 
            +
                dataset = combine_datasets(repo_id, dataset, oauth_token)
         | 
| 236 | 
             
                progress(0.9, desc="Pushing dataset")
         | 
| 237 | 
             
                distiset = Distiset({"default": dataset})
         | 
| 238 | 
             
                distiset.push_to_hub(
         | 
|  | |
| 596 | 
             
                            outputs=[dataset_description, system_prompt, num_turns, dataframe],
         | 
| 597 | 
             
                        )
         | 
| 598 | 
             
                        app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
| 599 | 
            +
                    app.load(fn=get_random_repo_name, outputs=[repo_name])
         | 
| 600 | 
             
                    app.load(fn=swap_visibility, outputs=main_ui)
         | 
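The `random.seed(a=random.randint(0, 2**32 - 1))` calls added after each batch re-randomize Python's global RNG, presumably so that batch-level sampling does not replay an identical stream when an earlier component has pinned the seed. A minimal self-contained illustration; the fixed seed of 42 stands in for whatever upstream code pins it:

import random

random.seed(42)  # stand-in for an upstream component fixing the global seed
batch_a = [random.randint(0, 9) for _ in range(5)]

random.seed(a=random.randint(0, 2**32 - 1))  # the line this PR adds after each batch
batch_b = [random.randint(0, 9) for _ in range(5)]

print(batch_a, batch_b)  # the second batch no longer replays the same stream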
    	
src/synthetic_dataset_generator/apps/eval.py CHANGED
@@ -41,6 +41,7 @@ from synthetic_dataset_generator.utils import (
     extract_column_names,
     get_argilla_client,
     get_org_dropdown,
+    get_random_repo_name,
     pad_or_truncate_list,
     process_columns,
     swap_visibility,
@@ -359,7 +360,7 @@ def push_dataset_to_hub(
 ):
     repo_id = validate_push_to_hub(org_name, repo_name)
     dataset = Dataset.from_pandas(dataframe)
-    dataset = combine_datasets(repo_id, dataset)
+    dataset = combine_datasets(repo_id, dataset, oauth_token)
     distiset = Distiset({"default": dataset})
     distiset.push_to_hub(
         repo_id=repo_id,
@@ -907,3 +908,4 @@ with gr.Blocks() as app:
 
     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
    	
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -20,6 +20,7 @@ from synthetic_dataset_generator.apps.base import (
     validate_push_to_hub,
 )
 from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
 from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
     get_sentence_embedding_dimensions,
@@ -35,6 +36,7 @@ from synthetic_dataset_generator.utils import (
     get_argilla_client,
     get_org_dropdown,
     get_preprocess_labels,
+    get_random_repo_name,
     swap_visibility,
 )
 
@@ -106,7 +108,7 @@ def generate_dataset(
     )
     updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
     if multi_label:
-        updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is better than applying too many labels."
+…
     labeller_generator = get_labeller_generator(
         system_prompt=updated_system_prompt,
         labels=labels,
@@ -118,6 +120,7 @@ def generate_dataset(
     # create text classification data
     n_processed = 0
     textcat_results = []
+…
     while n_processed < num_rows:
         progress(
             2 * 0.5 * n_processed / num_rows,
@@ -128,25 +131,24 @@ def generate_dataset(
         batch_size = min(batch_size, remaining_rows)
         inputs = []
         for _ in range(batch_size):
+…
             if multi_label:
                 num_labels = len(labels)
                 k = int(
                     random.betavariate(alpha=(num_labels - 1), beta=num_labels)
                     * num_labels
                 )
-            else:
-                k = 1
-
             sampled_labels = random.sample(labels, min(k, len(labels)))
             random.shuffle(sampled_labels)
             inputs.append(
                 {
-                    "task": f"{…
+…
                 }
             )
         batch = list(textcat_generator.process(inputs=inputs))
         textcat_results.extend(batch[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     for result in textcat_results:
         result["text"] = result["input_text"]
 
@@ -164,6 +166,7 @@ def generate_dataset(
         labels_batch = list(labeller_generator.process(inputs=batch))
         labeller_results.extend(labels_batch[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     progress(
         1,
         total=total_steps,
@@ -178,26 +181,43 @@ def generate_dataset(
 
     dataframe = pd.DataFrame(distiset_results)
     if multi_label:
-…
+…
                     label.lower().strip()
-                    if (label is not None and label.lower().strip() in labels)
-                    else random.choice(labels)
                     for label in x
-…
             )
-…
+…
         dataframe = dataframe[dataframe["labels"].notna()]
     else:
+…
         dataframe = dataframe.rename(columns={"labels": "label"})
-        dataframe["label"] = dataframe["label"].apply(
-            lambda x: x.lower().strip()
-            if x and x.lower().strip() in labels
-            else random.choice(labels)
-        )
     dataframe = dataframe[dataframe["text"].notna()]
 
     progress(1.0, desc="Dataset created")
@@ -235,7 +255,7 @@ def push_dataset_to_hub(
         dataframe.reset_index(drop=True),
         features=features,
     )
-    dataset = combine_datasets(repo_id, dataset)
+    dataset = combine_datasets(repo_id, dataset, oauth_token)
     distiset = Distiset({"default": dataset})
     progress(0.9, desc="Pushing dataset")
     distiset.push_to_hub(
@@ -647,3 +667,4 @@ with gr.Blocks() as app:
 
     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
                app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
|  | 
|  | |
| 20 | 
             
                validate_push_to_hub,
         | 
| 21 | 
             
            )
         | 
| 22 | 
             
            from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
         | 
| 23 | 
            +
            from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
         | 
| 24 | 
             
            from synthetic_dataset_generator.pipelines.embeddings import (
         | 
| 25 | 
             
                get_embeddings,
         | 
| 26 | 
             
                get_sentence_embedding_dimensions,
         | 
|  | |
| 36 | 
             
                get_argilla_client,
         | 
| 37 | 
             
                get_org_dropdown,
         | 
| 38 | 
             
                get_preprocess_labels,
         | 
| 39 | 
            +
                get_random_repo_name,
         | 
| 40 | 
             
                swap_visibility,
         | 
| 41 | 
             
            )
         | 
| 42 |  | 
|  | |
| 108 | 
             
                )
         | 
| 109 | 
             
                updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
         | 
| 110 | 
             
                if multi_label:
         | 
| 111 | 
            +
                    updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is always better than applying too many labels."
         | 
| 112 | 
             
                labeller_generator = get_labeller_generator(
         | 
| 113 | 
             
                    system_prompt=updated_system_prompt,
         | 
| 114 | 
             
                    labels=labels,
         | 
|  | |
| 120 | 
             
                # create text classification data
         | 
| 121 | 
             
                n_processed = 0
         | 
| 122 | 
             
                textcat_results = []
         | 
| 123 | 
            +
                rewritten_system_prompts = get_rewriten_prompts(system_prompt, num_rows)
         | 
| 124 | 
             
                while n_processed < num_rows:
         | 
| 125 | 
             
                    progress(
         | 
| 126 | 
             
                        2 * 0.5 * n_processed / num_rows,
         | 
|  | |
| 131 | 
             
                    batch_size = min(batch_size, remaining_rows)
         | 
| 132 | 
             
                    inputs = []
         | 
| 133 | 
             
                    for _ in range(batch_size):
         | 
| 134 | 
            +
                        k = 1
         | 
| 135 | 
             
                        if multi_label:
         | 
| 136 | 
             
                            num_labels = len(labels)
         | 
| 137 | 
             
                            k = int(
         | 
| 138 | 
             
                                random.betavariate(alpha=(num_labels - 1), beta=num_labels)
         | 
| 139 | 
             
                                * num_labels
         | 
| 140 | 
             
                            )
         | 
|  | |
|  | |
|  | |
| 141 | 
             
                        sampled_labels = random.sample(labels, min(k, len(labels)))
         | 
| 142 | 
             
                        random.shuffle(sampled_labels)
         | 
| 143 | 
             
                        inputs.append(
         | 
| 144 | 
             
                            {
         | 
| 145 | 
            +
                                "task": f"{random.choice(rewritten_system_prompts)}. The text represents the following categories: {', '.join(sampled_labels)}"
         | 
| 146 | 
             
                            }
         | 
| 147 | 
             
                        )
         | 
| 148 | 
             
                    batch = list(textcat_generator.process(inputs=inputs))
         | 
| 149 | 
             
                    textcat_results.extend(batch[0])
         | 
| 150 | 
             
                    n_processed += batch_size
         | 
| 151 | 
            +
                    random.seed(a=random.randint(0, 2**32 - 1))
         | 
| 152 | 
             
                for result in textcat_results:
         | 
| 153 | 
             
                    result["text"] = result["input_text"]
         | 
| 154 |  | 
|  | |
| 166 | 
             
                    labels_batch = list(labeller_generator.process(inputs=batch))
         | 
| 167 | 
             
                    labeller_results.extend(labels_batch[0])
         | 
| 168 | 
             
                    n_processed += batch_size
         | 
| 169 | 
            +
                    random.seed(a=random.randint(0, 2**32 - 1))
         | 
| 170 | 
             
                progress(
         | 
| 171 | 
             
                    1,
         | 
| 172 | 
             
                    total=total_steps,
         | 
|  | |
| 181 |  | 
| 182 | 
             
                dataframe = pd.DataFrame(distiset_results)
         | 
| 183 | 
             
                if multi_label:
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    def _validate_labels(x):
         | 
| 186 | 
            +
                        if isinstance(x, str):  # single label
         | 
| 187 | 
            +
                            return [x.lower().strip()]
         | 
| 188 | 
            +
                        elif isinstance(x, list):  # multiple labels
         | 
| 189 | 
            +
                            return list(
         | 
| 190 | 
            +
                                set(
         | 
| 191 | 
             
                                    label.lower().strip()
         | 
|  | |
|  | |
| 192 | 
             
                                    for label in x
         | 
| 193 | 
            +
                                    if label.lower().strip() in labels
         | 
| 194 | 
            +
                                )
         | 
| 195 | 
             
                            )
         | 
| 196 | 
            +
                        else:
         | 
| 197 | 
            +
                            return list(set([random.choice(labels)]))
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    dataframe["labels"] = dataframe["labels"].apply(_validate_labels)
         | 
| 200 | 
             
                    dataframe = dataframe[dataframe["labels"].notna()]
         | 
| 201 | 
             
                else:
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    def _validate_labels(x):
         | 
| 204 | 
            +
                        if isinstance(x, str) and x.lower().strip() in labels:
         | 
| 205 | 
            +
                            return x.lower().strip()
         | 
| 206 | 
            +
                        elif isinstance(x, list):
         | 
| 207 | 
            +
                            options = [
         | 
| 208 | 
            +
                                label.lower().strip()
         | 
| 209 | 
            +
                                for label in x
         | 
| 210 | 
            +
                                if isinstance(label, str) and label.lower().strip() in labels
         | 
| 211 | 
            +
                            ]
         | 
| 212 | 
            +
                            if options:
         | 
| 213 | 
            +
                                return random.choice(options)
         | 
| 214 | 
            +
                            else:
         | 
| 215 | 
            +
                                return random.choice(labels)
         | 
| 216 | 
            +
                        else:
         | 
| 217 | 
            +
                            return random.choice(labels)
         | 
| 218 | 
            +
             | 
| 219 | 
             
                    dataframe = dataframe.rename(columns={"labels": "label"})
         | 
| 220 | 
            +
                    dataframe["label"] = dataframe["label"].apply(_validate_labels)
         | 
|  | |
|  | |
|  | |
|  | |
| 221 | 
             
                dataframe = dataframe[dataframe["text"].notna()]
         | 
| 222 |  | 
| 223 | 
             
                progress(1.0, desc="Dataset created")
         | 
|  | |
| 255 | 
             
                    dataframe.reset_index(drop=True),
         | 
| 256 | 
             
                    features=features,
         | 
| 257 | 
             
                )
         | 
| 258 | 
            +
                dataset = combine_datasets(repo_id, dataset, oauth_token)
         | 
| 259 | 
             
                distiset = Distiset({"default": dataset})
         | 
| 260 | 
             
                progress(0.9, desc="Pushing dataset")
         | 
| 261 | 
             
                distiset.push_to_hub(
         | 
|  | |
| 667 |  | 
| 668 | 
             
                app.load(fn=swap_visibility, outputs=main_ui)
         | 
| 669 | 
             
                app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
| 670 | 
            +
                app.load(fn=get_random_repo_name, outputs=[repo_name])
         | 
    	
src/synthetic_dataset_generator/constants.py CHANGED

@@ -7,39 +7,66 @@ import argilla as rg
 TEXTCAT_TASK = "text_classification"
 SFT_TASK = "supervised_fine_tuning"

-# …
+# Inference
+MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
+MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
+DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
+
+# Models
+MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
+TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
+OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
+OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
+HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
+VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
+
+# check if model is set correctly
+if HUGGINGFACE_BASE_URL and MODEL:
+    raise ValueError(
+        "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
+    )
+if not MODEL:
+    if OPENAI_BASE_URL or OLLAMA_BASE_URL or VLLM_BASE_URL:
+        raise ValueError("`MODEL` is not set. Please provide a model id for inference.")
+
+# Check if multiple base URLs are provided
+base_urls = [
+    url
+    for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
+    if url
+]
+if len(base_urls) > 1:
+    raise ValueError(
+        f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time."
+    )
+BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
+
+
+# API Keys
 HF_TOKEN = os.getenv("HF_TOKEN")
 if not HF_TOKEN:
     raise ValueError(
         "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
     )

-# Inference
-MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
-MAX_NUM_ROWS: str | int = int(os.getenv("MAX_NUM_ROWS", 1000))
-DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
-MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
-BASE_URL = os.getenv("BASE_URL", default=None)
-…
 _API_KEY = os.getenv("API_KEY")
-…
-…
-…
-…
-…
-    ]
+API_KEYS = (
+    [_API_KEY]
+    if _API_KEY
+    else [HF_TOKEN] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
+)
 API_KEYS = [token for token in API_KEYS if token]

 # Determine if SFT is available
 SFT_AVAILABLE = False
 llama_options = ["llama3", "llama-3", "llama 3"]
 qwen_options = ["qwen2", "qwen-2", "qwen 2"]
-…
+
+if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
     SFT_AVAILABLE = True
-    passed_pre_query_template …
-    if passed_pre_query_template.lower() in llama_options:
+    if passed_pre_query_template in llama_options:
         MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
-    elif passed_pre_query_template …
+    elif passed_pre_query_template in qwen_options:
         MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
     else:
         MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
@@ -54,12 +81,12 @@ elif MODEL.lower() in qwen_options or any(
     SFT_AVAILABLE = True
     MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"

-if …
+if OPENAI_BASE_URL:
     SFT_AVAILABLE = False

 if not SFT_AVAILABLE:
     warnings.warn(
-        …
+        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` to with vllm."
     )
     MAGPIE_PRE_QUERY_TEMPLATE = None

@@ -67,11 +94,12 @@ if not SFT_AVAILABLE:
 STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"

 # Argilla
-ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
-…
-…
-…
-…
+ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
+    "ARGILLA_API_URL_SDG_REVIEWER"
+)
+ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
+    "ARGILLA_API_KEY_SDG_REVIEWER"
+)

 if not ARGILLA_API_URL or not ARGILLA_API_KEY:
     warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
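With these constants, the deployment target is selected entirely through environment variables: at most one `*_BASE_URL` may be set, `MODEL` names the model to use, and `TOKENIZER_ID` can override the tokenizer for Magpie-style chat generation. A minimal sketch of a local Ollama configuration follows; the URL, model name, and tokenizer id are assumptions in the spirit of the scripts in `examples/`, and `launch()` is the entry point shown in the README:

```python
import os

# Assumed values for a local Ollama deployment; adjust to your setup.
os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"
os.environ["MODEL"] = "llama3.1"  # model name as served by Ollama
os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct"  # HF tokenizer id
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # keeps SFT_AVAILABLE True

# Import only after the variables are set: constants.py reads them at import time.
from synthetic_dataset_generator import launch

launch()
```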
    	
src/synthetic_dataset_generator/pipelines/base.py CHANGED

@@ -1,4 +1,21 @@
-…
+import math
+import random
+
+import gradio as gr
+from distilabel.llms import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
+from distilabel.steps.tasks import TextGeneration
+
+from synthetic_dataset_generator.constants import (
+    API_KEYS,
+    DEFAULT_BATCH_SIZE,
+    HUGGINGFACE_BASE_URL,
+    MAGPIE_PRE_QUERY_TEMPLATE,
+    MODEL,
+    OLLAMA_BASE_URL,
+    OPENAI_BASE_URL,
+    TOKENIZER_ID,
+    VLLM_BASE_URL,
+)

 TOKEN_INDEX = 0

@@ -8,3 +25,117 @@ def _get_next_api_key():
     api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
     TOKEN_INDEX += 1
     return api_key
+
+
+def _get_prompt_rewriter():
+    generation_kwargs = {
+        "temperature": 1,
+    }
+    system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new."
+    prompt_rewriter = TextGeneration(
+        llm=_get_llm(generation_kwargs=generation_kwargs),
+        system_prompt=system_prompt,
+        use_system_prompt=True,
+    )
+    prompt_rewriter.load()
+    return prompt_rewriter
+
+
+def get_rewriten_prompts(prompt: str, num_rows: int):
+    prompt_rewriter = _get_prompt_rewriter()
+    # create prompt rewrites
+    inputs = [
+        {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
+        for i in range(math.floor(num_rows / 100))
+    ]
+    n_processed = 0
+    prompt_rewrites = [prompt]
+    while n_processed < num_rows:
+        batch = list(
+            prompt_rewriter.process(
+                inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
+            )
+        )
+        prompt_rewrites += [entry["generation"] for entry in batch[0]]
+        n_processed += DEFAULT_BATCH_SIZE
+        random.seed(a=random.randint(0, 2**32 - 1))
+    return prompt_rewrites
+
+
+def _get_llm(use_magpie_template=False, **kwargs):
+    if OPENAI_BASE_URL:
+        llm = OpenAILLM(
+            model=MODEL,
+            base_url=OPENAI_BASE_URL,
+            api_key=_get_next_api_key(),
+            **kwargs,
+        )
+        if "generation_kwargs" in kwargs:
+            if "stop_sequences" in kwargs["generation_kwargs"]:
+                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
+                    "stop_sequences"
+                ]
+                del kwargs["generation_kwargs"]["stop_sequences"]
+            if "do_sample" in kwargs["generation_kwargs"]:
+                del kwargs["generation_kwargs"]["do_sample"]
+    elif OLLAMA_BASE_URL:
+        if "generation_kwargs" in kwargs:
+            if "max_new_tokens" in kwargs["generation_kwargs"]:
+                kwargs["generation_kwargs"]["num_predict"] = kwargs[
+                    "generation_kwargs"
+                ]["max_new_tokens"]
+                del kwargs["generation_kwargs"]["max_new_tokens"]
+            if "stop_sequences" in kwargs["generation_kwargs"]:
+                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
+                    "stop_sequences"
+                ]
+                del kwargs["generation_kwargs"]["stop_sequences"]
+            if "do_sample" in kwargs["generation_kwargs"]:
+                del kwargs["generation_kwargs"]["do_sample"]
+            options = kwargs["generation_kwargs"]
+            del kwargs["generation_kwargs"]
+            kwargs["generation_kwargs"] = {}
+            kwargs["generation_kwargs"]["options"] = options
+        llm = OllamaLLM(
+            model=MODEL,
+            host=OLLAMA_BASE_URL,
+            tokenizer_id=TOKENIZER_ID or MODEL,
+            **kwargs,
+        )
+    elif HUGGINGFACE_BASE_URL:
+        kwargs["generation_kwargs"]["do_sample"] = True
+        llm = InferenceEndpointsLLM(
+            api_key=_get_next_api_key(),
+            base_url=HUGGINGFACE_BASE_URL,
+            tokenizer_id=TOKENIZER_ID or MODEL,
+            **kwargs,
+        )
+    elif VLLM_BASE_URL:
+        if "generation_kwargs" in kwargs:
+            if "do_sample" in kwargs["generation_kwargs"]:
+                del kwargs["generation_kwargs"]["do_sample"]
+        llm = ClientvLLM(
+            base_url=VLLM_BASE_URL,
+            model=MODEL,
+            tokenizer=TOKENIZER_ID or MODEL,
+            api_key=_get_next_api_key(),
+            **kwargs,
+        )
+    else:
+        llm = InferenceEndpointsLLM(
+            api_key=_get_next_api_key(),
+            tokenizer_id=TOKENIZER_ID or MODEL,
+            model_id=MODEL,
+            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
+            **kwargs,
+        )
+
+    return llm
+
+
+try:
+    llm = _get_llm()
+    llm.load()
+    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
+except Exception as e:
+    gr.Error(f"Error loading {llm.__class__.__name__}: {e}")
    	
src/synthetic_dataset_generator/pipelines/chat.py CHANGED

@@ -1,4 +1,3 @@
-from distilabel.llms import InferenceEndpointsLLM
 from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration

 from synthetic_dataset_generator.constants import (
@@ -7,7 +6,7 @@ from synthetic_dataset_generator.constants import (
     MAX_NUM_TOKENS,
     MODEL,
 )
-from synthetic_dataset_generator.pipelines.base import …
+from synthetic_dataset_generator.pipelines.base import _get_llm

 INFORMATION_SEEKING_PROMPT = (
     "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -149,18 +148,13 @@ def _get_output_mappings(num_turns):


 def get_prompt_generator():
+    generation_kwargs = {
+        "temperature": 0.8,
+        "max_new_tokens": MAX_NUM_TOKENS,
+        "do_sample": True,
+    }
     prompt_generator = TextGeneration(
-        llm=…
-            api_key=_get_next_api_key(),
-            model_id=MODEL,
-            tokenizer_id=MODEL,
-            base_url=BASE_URL,
-            generation_kwargs={
-                "temperature": 0.8,
-                "max_new_tokens": MAX_NUM_TOKENS,
-                "do_sample": True,
-            },
-        ),
+        llm=_get_llm(generation_kwargs=generation_kwargs),
         system_prompt=PROMPT_CREATION_PROMPT,
         use_system_prompt=True,
     )
@@ -172,38 +166,34 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
     input_mappings = _get_output_mappings(num_turns)
     output_mappings = input_mappings.copy()
     if num_turns == 1:
+        generation_kwargs = {
+            "temperature": temperature,
+            "do_sample": True,
+            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
+            "stop_sequences": _STOP_SEQUENCES,
+        }
         magpie_generator = Magpie(
-            llm=…
-            …
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
+            llm=_get_llm(
+                generation_kwargs=generation_kwargs,
                 magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
-                …
-                    "temperature": temperature,
-                    "do_sample": True,
-                    "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
-                    "stop_sequences": _STOP_SEQUENCES,
-                },
+                use_magpie_template=True,
             ),
             n_turns=num_turns,
             output_mappings=output_mappings,
             only_instruction=True,
         )
     else:
+        generation_kwargs = {
+            "temperature": temperature,
+            "do_sample": True,
+            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
+            "stop_sequences": _STOP_SEQUENCES,
+        }
         magpie_generator = Magpie(
-            llm=…
-            …
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
+            llm=_get_llm(
+                generation_kwargs=generation_kwargs,
                 magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
-                …
-                    "temperature": temperature,
-                    "do_sample": True,
-                    "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
-                    "stop_sequences": _STOP_SEQUENCES,
-                },
+                use_magpie_template=True,
             ),
             end_with_user=True,
             n_turns=num_turns,
@@ -213,51 +203,25 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
     return magpie_generator


-def get_prompt_rewriter():
-    prompt_rewriter = TextGeneration(
-        llm=InferenceEndpointsLLM(
-            model_id=MODEL,
-            tokenizer_id=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs={
-                "temperature": 1,
-            },
-        ),
-    )
-    prompt_rewriter.load()
-    return prompt_rewriter
-
-
 def get_response_generator(system_prompt, num_turns, temperature, is_sample):
     if num_turns == 1:
+        generation_kwargs = {
+            "temperature": temperature,
+            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
+        }
         response_generator = TextGeneration(
-            llm=…
-                model_id=MODEL,
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
-                generation_kwargs={
-                    "temperature": temperature,
-                    "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
-                },
-            ),
+            llm=_get_llm(generation_kwargs=generation_kwargs),
             system_prompt=system_prompt,
             output_mappings={"generation": "completion"},
            input_mappings={"instruction": "prompt"},
         )
     else:
+        generation_kwargs = {
+            "temperature": temperature,
+            "max_new_tokens": MAX_NUM_TOKENS,
+        }
         response_generator = ChatGeneration(
-            llm=…
-                model_id=MODEL,
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
-                generation_kwargs={
-                    "temperature": temperature,
-                    "max_new_tokens": MAX_NUM_TOKENS,
-                },
-            ),
+            llm=_get_llm(generation_kwargs=generation_kwargs),
             output_mappings={"generation": "completion"},
             input_mappings={"conversation": "messages"},
         )
@@ -293,7 +257,7 @@ with Pipeline(name="sft") as pipeline:
             "max_new_tokens": {MAX_NUM_TOKENS},
             "stop_sequences": {_STOP_SEQUENCES}
         }},
-        api_key=os.environ["…"],
+        api_key=os.environ["API_KEY"],
     ),
     n_turns={num_turns},
     num_rows={num_rows},
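After this refactor every generator in the chat pipeline builds its LLM the same way: assemble a `generation_kwargs` dict, then call `_get_llm`, which picks the backend from the environment and normalizes the kwargs. A hedged usage sketch of the refactored helper (prompt text and parameter values here are illustrative only, not from the repo):

```python
from synthetic_dataset_generator.pipelines.chat import get_response_generator

# Build a single-turn response generator; backend selection happens inside _get_llm.
response_generator = get_response_generator(
    system_prompt="You are a concise assistant.",  # illustrative prompt
    num_turns=1,
    temperature=0.7,
    is_sample=True,  # caps max_new_tokens at 256 for quick sampling
)
response_generator.load()

# With input_mappings={"instruction": "prompt"}, inputs are passed under "prompt";
# the output is remapped from "generation" to "completion".
batch = next(response_generator.process([{"prompt": "What does Magpie do?"}]))
print(batch[0]["completion"])
```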
    	
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED

@@ -1,7 +1,6 @@
 import random
 from typing import List
 
-from distilabel.llms import InferenceEndpointsLLM, OpenAILLM
 from distilabel.steps.tasks import (
     GenerateTextClassificationData,
     TextClassification,
@@ -9,8 +8,12 @@ from distilabel.steps.tasks import (
 )
 from pydantic import BaseModel, Field
 
-from synthetic_dataset_generator.constants import BASE_URL, MAX_NUM_TOKENS, MODEL
-from synthetic_dataset_generator.pipelines.base import _get_next_api_key
+from synthetic_dataset_generator.constants import (
+    BASE_URL,
+    MAX_NUM_TOKENS,
+    MODEL,
+)
+from synthetic_dataset_generator.pipelines.base import _get_llm
 from synthetic_dataset_generator.utils import get_preprocess_labels
 
 PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
@@ -69,23 +72,10 @@ def get_prompt_generator():
         "temperature": 0.8,
         "max_new_tokens": MAX_NUM_TOKENS,
     }
-    if BASE_URL:
-        llm = OpenAILLM(
-            model=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            structured_output=structured_output,
-            generation_kwargs=generation_kwargs,
-        )
-    else:
-        generation_kwargs["do_sample"] = True
-        llm = InferenceEndpointsLLM(
-            api_key=_get_next_api_key(),
-            model_id=MODEL,
-            base_url=BASE_URL,
-            structured_output=structured_output,
-            generation_kwargs=generation_kwargs,
-        )
+    llm = _get_llm(
+        structured_output=structured_output,
+        generation_kwargs=generation_kwargs,
+    )
 
     prompt_generator = TextGeneration(
         llm=llm,
@@ -103,22 +93,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
         "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
         "top_p": 0.95,
     }
-    if BASE_URL:
-        llm = OpenAILLM(
-            model=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
-    else:
-        generation_kwargs["do_sample"] = True
-        llm = InferenceEndpointsLLM(
-            model_id=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
-
+    llm = _get_llm(generation_kwargs=generation_kwargs)
     textcat_generator = GenerateTextClassificationData(
         llm=llm,
         difficulty=None if difficulty == "mixed" else difficulty,
@@ -134,22 +109,7 @@ def get_labeller_generator(system_prompt, labels, multi_label):
         "temperature": 0.01,
         "max_new_tokens": MAX_NUM_TOKENS,
     }
-
-    if BASE_URL:
-        llm = OpenAILLM(
-            model=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
-    else:
-        llm = InferenceEndpointsLLM(
-            model_id=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
-
+    llm = _get_llm(generation_kwargs=generation_kwargs)
     labeller_generator = TextClassification(
         llm=llm,
         context=system_prompt,
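All three functions now delegate LLM construction to `_get_llm` from `src/synthetic_dataset_generator/pipelines/base.py`; that file also changed in this PR (+132 -1), but its new contents are not shown here. A minimal sketch of such a factory, assuming it simply centralizes the `BASE_URL` branching deleted above, with `_get_next_api_key` stubbed to read the documented environment variables:

import os

from distilabel.llms import InferenceEndpointsLLM, OpenAILLM

from synthetic_dataset_generator.constants import BASE_URL, MODEL


def _get_next_api_key():
    # Stub for illustration: the real helper may rotate over several keys.
    return os.environ.get("API_KEY") or os.environ["HF_TOKEN"]


def _get_llm(structured_output=None, generation_kwargs=None):
    generation_kwargs = dict(generation_kwargs or {})
    if BASE_URL:
        # An OpenAI-compatible endpoint (vLLM, Ollama, TGI, ...) is configured.
        return OpenAILLM(
            model=MODEL,
            base_url=BASE_URL,
            api_key=_get_next_api_key(),
            structured_output=structured_output,
            generation_kwargs=generation_kwargs,
        )
    # Default: Hugging Face Inference Endpoints, with sampling enabled.
    generation_kwargs["do_sample"] = True
    return InferenceEndpointsLLM(
        model_id=MODEL,
        api_key=_get_next_api_key(),
        structured_output=structured_output,
        generation_kwargs=generation_kwargs,
    )

With this in place, each pipeline only decides its generation parameters, and the choice of backend stays in one spot instead of being duplicated per task.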
    	
src/synthetic_dataset_generator/utils.py
CHANGED

@@ -1,4 +1,5 @@
 import json
+import uuid
 import warnings
 from typing import List, Optional, Union
 
@@ -55,6 +56,10 @@ def list_orgs(oauth_token: Union[OAuthToken, None] = None):
     return organizations
 
 
+def get_random_repo_name():
+    return f"my-distiset-{str(uuid.uuid4())[:8]}"
+
+
 def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
    if oauth_token is not None:
        orgs = list_orgs(oauth_token)
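`get_random_repo_name` gives each push-to-Hub run a collision-resistant default dataset name: the first eight hex characters of a UUID4 keep the suggestion short while making accidental repo-name clashes unlikely. For illustration:

import uuid

def get_random_repo_name():
    return f"my-distiset-{str(uuid.uuid4())[:8]}"

print(get_random_repo_name())  # e.g. "my-distiset-3f2a9c1b"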

