Spaces:
Build error
Build error
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
| import os | |
| import random | |
| from typing import Dict, List, Optional, Tuple | |
| from .audio_processor import AudioProcessor | |
| from ..configs.config import AudioConfig, Config | |
class SpeakerDataset(Dataset):
    """Dataset serving the preprocessed audio clips of a single speaker.

    Each item is a dict with three keys:

    - ``'audio'``: float tensor built from the first value returned by
      ``audio_processor.preprocess_audio``.
    - ``'mel_spec'``: float tensor built from the second returned value.
    - ``'file_path'``: path of the file the item was produced from.

    Processed items are memoised by file path until ``cache_size`` entries
    are stored; after that new items are returned but not cached (there is
    no eviction).
    """

    def __init__(
        self,
        audio_files: List[str],
        audio_processor: "AudioProcessor",
        cache_size: int = 100,  # maximum number of processed items kept in memory
    ):
        """Store the file list and processor and start with an empty cache.

        Args:
            audio_files: Paths of the audio files belonging to the speaker.
            audio_processor: Object exposing ``preprocess_audio(path)`` that
                returns an ``(audio, mel_spec)`` pair.
            cache_size: Upper bound on the number of cached items.
        """
        self.audio_files = audio_files
        self.audio_processor = audio_processor
        self.cache = {}
        self.cache_size = cache_size

    def __len__(self) -> int:
        """Return the number of audio files in the dataset."""
        return len(self.audio_files)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the item at ``idx``, skipping files that cannot be processed.

        When the requested file fails to load, the following files are tried
        in order (wrapping around) and the first loadable one is returned
        instead. Every file is attempted at most once per call.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
            RuntimeError: If none of the files could be processed.
        """
        total = len(self.audio_files)
        # Validate explicitly: the wrap-around below would otherwise hide an
        # out-of-range index, and sequence-style iteration relies on
        # IndexError to terminate.
        if not -total <= idx < total:
            raise IndexError("SpeakerDataset index out of range")

        start = idx % total
        last_error = None
        # Bounded scan: previously this recursed without limit, which ended in
        # a RecursionError when every file in the dataset was unreadable.
        for offset in range(total):
            audio_path = self.audio_files[(start + offset) % total]

            # Serve from the cache when this file was already processed.
            if audio_path in self.cache:
                return self.cache[audio_path]

            try:
                audio, mel_spec = self.audio_processor.preprocess_audio(audio_path)
                item = {
                    'audio': torch.FloatTensor(audio),
                    'mel_spec': torch.FloatTensor(mel_spec),
                    'file_path': audio_path
                }
            except Exception as e:
                print(f"Error processing file {audio_path}: {str(e)}")
                last_error = e
                continue

            # Only cache while there is room left.
            if len(self.cache) < self.cache_size:
                self.cache[audio_path] = item
            return item

        raise RuntimeError(
            f"None of the {total} audio files in the dataset could be processed"
        ) from last_error
class VoiceDatasetManager:
    """Index a directory of per-speaker audio folders and hand out datasets.

    ``root_dir`` is expected to contain one sub-directory per speaker; the
    sub-directory name is used as the speaker id. The resulting mapping
    ``speaker id -> list of audio file paths`` is stored in ``self.speakers``.
    """

    def __init__(
        self,
        root_dir: str,
        audio_processor: Optional["AudioProcessor"] = None,
        config: Optional["Config"] = None
    ):
        """Build the speaker index for ``root_dir``.

        Args:
            root_dir: Directory holding one sub-directory per speaker.
            audio_processor: Processor to share with the datasets; when
                omitted one is created from ``config.audio``.
            config: Configuration object; when omitted ``Config()`` is used.
        """
        self.root_dir = root_dir
        # Compare against None so that a supplied object is never replaced
        # merely because it happens to evaluate as falsy.
        self.config = config if config is not None else Config()
        if audio_processor is None:
            audio_processor = AudioProcessor(config=self.config.audio)
        self.audio_processor = audio_processor
        self.speakers = self._scan_speakers()

    def _scan_speakers(self) -> Dict[str, List[str]]:
        """Collect the usable audio files of every speaker under ``root_dir``.

        A file is kept when its name ends with one of
        ``config.data.valid_audio_extensions`` (compared case-insensitively)
        and its size is non-zero. Speakers with fewer than
        ``config.data.min_samples_per_speaker`` files are reported and left
        out. Directories and files are visited in sorted order so the result
        does not depend on the file system's listing order.

        Returns:
            Mapping from speaker id to the list of that speaker's file paths.
        """
        data_cfg = self.config.data
        extensions = data_cfg.valid_audio_extensions
        # str.endswith only accepts a str or a tuple of str; normalise so a
        # list (or a single string) from the config works as well.
        if isinstance(extensions, str):
            extensions = (extensions,)
        extensions = tuple(ext.lower() for ext in extensions)
        min_samples = data_cfg.min_samples_per_speaker

        speakers = {}
        for speaker_id in sorted(os.listdir(self.root_dir)):
            speaker_dir = os.path.join(self.root_dir, speaker_id)
            if not os.path.isdir(speaker_dir):
                continue

            audio_files = []
            # Walk all nested sub-directories of the speaker folder.
            for root, dirs, files in os.walk(speaker_dir):
                dirs.sort()  # in-place sort steers os.walk's descent order
                for file_name in sorted(files):
                    if not file_name.lower().endswith(extensions):
                        continue
                    audio_path = os.path.join(root, file_name)
                    # Skip files that are empty or cannot be inspected
                    # (e.g. broken symlinks, files removed mid-scan).
                    try:
                        if os.path.getsize(audio_path) > 0:
                            audio_files.append(audio_path)
                    except OSError:
                        continue

            # Keep only speakers with enough samples.
            if len(audio_files) >= min_samples:
                speakers[speaker_id] = audio_files
            else:
                print(f"Warning: Speaker {speaker_id} has insufficient samples")
        return speakers

    def get_speaker_dataset(self, speaker_id: str) -> "SpeakerDataset":
        """Return a ``SpeakerDataset`` over the files of ``speaker_id``.

        Raises:
            ValueError: If the speaker is not part of the index.
        """
        if speaker_id not in self.speakers:
            raise ValueError(f"Speaker {speaker_id} not found in dataset")
        return SpeakerDataset(
            self.speakers[speaker_id],
            self.audio_processor,
            cache_size=self.config.data.cache_size
        )
class MetaLearningDataset(Dataset):
    """Episodic dataset for few-shot voice cloning.

    Every item is one task: ``n_way`` randomly chosen speakers, each
    contributing ``k_shot`` support samples and ``k_query`` query samples.
    All sizes are read from ``config.meta_learning``.
    """

    def __init__(
        self,
        dataset_manager: "VoiceDatasetManager",
        config: "Config"
    ):
        """Select the speakers that have enough files to form a task.

        Args:
            dataset_manager: Provides ``speakers`` (speaker id -> file list)
                and ``audio_processor``.
            config: Configuration whose ``meta_learning`` section supplies
                ``n_way``, ``k_shot``, ``k_query`` and ``n_tasks``.

        Raises:
            ValueError: If fewer than ``n_way`` speakers have at least
                ``k_shot + k_query`` files.
        """
        self.dataset_manager = dataset_manager
        self.config = config

        meta = config.meta_learning
        samples_needed = meta.k_shot + meta.k_query
        # A speaker is usable only with enough files for support plus query.
        available_speakers = [
            speaker_id for speaker_id, files in dataset_manager.speakers.items()
            if len(files) >= samples_needed
        ]
        if len(available_speakers) < meta.n_way:
            raise ValueError(
                f"Not enough speakers with sufficient samples. "
                f"Need {meta.n_way} speakers with "
                f"{samples_needed} samples each, "
                f"but only found {len(available_speakers)}"
            )
        self.available_speakers = available_speakers

    def __len__(self) -> int:
        """Return the configured number of tasks per epoch (``n_tasks``)."""
        return self.config.meta_learning.n_tasks

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Sample one task; ``idx`` is ignored because tasks are drawn at random.

        For each selected speaker the files are tried in random order until
        ``k_shot + k_query`` of them have loaded: the first ``k_shot`` go to
        the support set and the rest to the query set. A file that fails to
        load is reported and replaced by another file of the same speaker,
        so one unreadable file no longer shrinks the task. If a speaker runs
        out of files a warning is printed and the task continues with what
        was loaded.

        Returns:
            ``(support_data, query_data)``; each is a dict with

            - ``'mel_spec'``: stacked mel spectrograms, documented by the
              original author as ``[n_samples, n_mels, time]``.
            - ``'speaker_ids'``: long tensor holding each sample's speaker
              index within the task (``0 .. n_way - 1``).

        Raises:
            RuntimeError: If the support set or the query set ends up empty.
        """
        meta = self.config.meta_learning
        k_shot = meta.k_shot
        samples_needed = k_shot + meta.k_query
        processor = self.dataset_manager.audio_processor

        # Pick the speakers taking part in this task.
        selected_speakers = random.sample(self.available_speakers, meta.n_way)

        support_data = {'mel_spec': [], 'speaker_ids': []}
        query_data = {'mel_spec': [], 'speaker_ids': []}

        for speaker_idx, speaker_id in enumerate(selected_speakers):
            speaker_files = self.dataset_manager.speakers[speaker_id]
            # A full random permutation: its prefix is a uniform random
            # sample, and the tail provides replacements for failed files.
            candidates = random.sample(speaker_files, len(speaker_files))

            loaded = 0
            for file_path in candidates:
                if loaded >= samples_needed:
                    break
                try:
                    _, mel_spec = processor.preprocess_audio(file_path)
                    mel_tensor = torch.FloatTensor(mel_spec)
                except Exception as e:
                    print(f"Error processing {file_path}: {str(e)}")
                    continue

                target_dict = support_data if loaded < k_shot else query_data
                target_dict['mel_spec'].append(mel_tensor)
                target_dict['speaker_ids'].append(speaker_idx)
                loaded += 1

            if loaded < samples_needed:
                print(
                    f"Warning: Speaker {speaker_id} yielded only {loaded} of "
                    f"{samples_needed} required samples"
                )

        # Convert the collected lists to tensors.
        # NOTE(review): torch.stack needs equally shaped spectrograms; this
        # assumes preprocess_audio pads/crops to a fixed length - confirm.
        for data_dict in (support_data, query_data):
            if not data_dict['mel_spec']:
                raise RuntimeError("No valid samples found for task")
            data_dict['mel_spec'] = torch.stack(data_dict['mel_spec'])
            data_dict['speaker_ids'] = torch.LongTensor(data_dict['speaker_ids'])
        return support_data, query_data
def create_meta_learning_dataloader(
    root_dir: str,
    config: Optional["Config"] = None,
    **kwargs
) -> DataLoader:
    """Create the data loader used for meta-learning.

    Args:
        root_dir: Root directory of the dataset.
        config: Configuration object; ``Config()`` is used when omitted.
            Note that overrides from ``kwargs`` are written into this object
            in place.
        **kwargs: Overrides for attributes of ``config.meta_learning``
            (matched by attribute name). Names that do not exist there are
            reported and ignored.

    Returns:
        DataLoader yielding one ``(support_data, query_data)`` task per step.
    """
    if config is None:
        config = Config()

    # Apply the overrides to the meta-learning section of the config.
    for key, value in kwargs.items():
        if hasattr(config.meta_learning, key):
            setattr(config.meta_learning, key, value)
        else:
            # Previously dropped without notice, which hid misspelled options.
            print(f"Warning: ignoring unknown meta-learning option '{key}'")

    # Build the dataset manager and the episodic dataset on top of it.
    dataset_manager = VoiceDatasetManager(root_dir, config=config)
    dataset = MetaLearningDataset(dataset_manager, config)

    return DataLoader(
        dataset,
        batch_size=1,  # fixed to 1: every dataset item is already a complete task
        shuffle=True,
        num_workers=0,  # single process, avoids multiprocessing issues
        # Pinned memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine just makes PyTorch emit a warning.
        pin_memory=torch.cuda.is_available(),
        collate_fn=lambda x: x[0]  # strip the batch dimension
    )