#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 20 12:51:05 2025
Prompt weighting for WanVideo
Adapted and heavily modified from https://github.com/xhinker/sd_embed
License: Apache 2.0
@author: blyss
"""
from transformers import T5Model
import torch
import re
from typing import Dict, List, Optional, Tuple, Union
from blissful_tuner.utils import BlissfulLogger

logger = BlissfulLogger(__name__, "#8e00ed")

# Matches "(text:weight)" where weight is a plain decimal literal such as
# "1", "1.", "1.25" or ".5". Restricting the weight group to well-formed numbers
# keeps float() from raising on typos like "(cat:1.2.3)"; such groups simply do
# not match and are left in the prompt unchanged.
_WEIGHT_PATTERN = re.compile(r'\((.*?):(\d+(?:\.\d*)?|\.\d+)\)')


class MiniT5Wrapper():
    """A mini wrapper for the T5 to make managing prompt weighting in Musubi easier"""

    def __init__(self, device: torch.device, dtype: torch.dtype, t5: T5Model):
        """Store the encoder and bookkeeping state.

        Args:
            device: Device recorded on the wrapper.
            dtype: Dtype recorded on the wrapper.
            t5: Text encoder. It is invoked as ``t5(prompts, device)`` and must
                expose a ``model`` attribute.
        """
        self.device = device
        self.dtype = dtype
        self.t5 = t5
        self.model = t5.model
        self.times_called = 0  # Used to emit the "Weighting prompts..." notice only once

    def __call__(
        self,
        prompt: Union[str, List[str]],
        device: torch.device,
        max_len: Optional[int] = None
    ) -> List[torch.Tensor]:
        """Encode a prompt and scale its embeddings by any inline weights.

        The prompt is split on ``'|'`` into chunks. Every ``(text:weight)`` group
        inside a chunk is reduced to ``text`` before encoding, and afterwards the
        chunk's whole embedding is multiplied by each weight found in that chunk.

        Args:
            prompt: A single prompt string, or a list holding exactly one prompt.
            device: Passed through to the wrapped encoder.
            max_len: Accepted for interface compatibility; not used here.

        Returns:
            The encoder output with one (possibly rescaled) entry per chunk.

        Raises:
            ValueError: If ``prompt`` is a list whose length is not 1.
        """
        if isinstance(prompt, list):
            if len(prompt) != 1:
                raise ValueError("MiniT5Wrapper expects a single prompt at a time (wrapped as a list). Got multiple prompts.")
            prompt = prompt[0]
        if self.times_called == 0:  # Only print this notice once even if called multiple times
            logger.info("Weighting prompts...")

        # Split the prompt into chunks and strip the weight syntax from each
        prompts = []
        all_weights = []
        for chunk in prompt.split('|'):
            cleaned_prompt, weights = self.parse_prompt_weights(chunk.strip())
            prompts.append(cleaned_prompt)
            all_weights.append(weights)

        context = self.t5(prompts, device)

        # Apply weights to embeddings if any were extracted. Each weight scales
        # the entire chunk embedding, so several weights in one chunk compound.
        for i, weights in enumerate(all_weights):
            for text, weight in weights.items():
                logger.info(f"Applying weight ({weight}) to promptchunk: '{text}'")
                context[i] = context[i] * weight

        self.times_called += 1
        return context

    def parse_prompt_weights(self, prompt: str) -> Tuple[str, Dict[str, float]]:
        """Extract text and weights from prompts with (text:weight) format

        Args:
            prompt: Prompt text that may contain ``(text:weight)`` groups.

        Returns:
            A tuple ``(cleaned_prompt, weights)``. ``cleaned_prompt`` is the
            prompt with every well-formed group replaced by its text, and
            ``weights`` maps each extracted text to its weight. If the same text
            appears more than once, the last weight wins. Groups whose weight is
            not a valid decimal number are left in the prompt untouched.
        """
        weights: Dict[str, float] = {}

        def _strip_weight(match: "re.Match") -> str:
            # Record the weight and keep just the text part in the prompt
            text = match.group(1)
            weights[text] = float(match.group(2))
            return text

        # Single pass: every match is replaced exactly where it was found
        cleaned_prompt = _WEIGHT_PATTERN.sub(_strip_weight, prompt)
        return cleaned_prompt, weights