"""
Paraformer model implementation for Hugging Face Transformers.
This module implements the Paraformer model for legal document retrieval,
based on the paper "Attentive Deep Neural Networks for Legal Document Retrieval".
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
try:
from .configuration_paraformer import ParaformerConfig
except ImportError:
from configuration_paraformer import ParaformerConfig
logger = logging.get_logger(__name__)
def sparsemax(input_tensor, dim=-1):
"""
Sparsemax activation function.
Args:
input_tensor: Input tensor
dim: Dimension along which to apply sparsemax
Returns:
Sparsemax output tensor
"""
    # Sort input in descending order
    sorted_input, _ = torch.sort(input_tensor, dim=dim, descending=True)
    # Cumulative sum of sorted values, shifted by the simplex constraint (-1)
    input_cumsum = torch.cumsum(sorted_input, dim=dim) - 1
    # Range tensor k = 1, ..., K, broadcast along `dim`
    k = torch.arange(1, input_tensor.size(dim) + 1, dtype=input_tensor.dtype, device=input_tensor.device)
    shape = [1] * input_tensor.dim()
    shape[dim] = -1
    k = k.view(shape)
    # Support: positions where 1 + k * z_(k) > cumsum_k(z)
    support = k * sorted_input > input_cumsum
    support_size = torch.sum(support.to(input_tensor.dtype), dim=dim, keepdim=True)
    # Threshold tau = (sum of supported sorted values - 1) / |support|
    tau_sum = torch.sum(sorted_input * support.to(input_tensor.dtype), dim=dim, keepdim=True)
    tau = (tau_sum - 1) / support_size
    # Project onto the probability simplex
    return torch.clamp(input_tensor - tau, min=0)
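# Illustrative check (hand-computed, not from the original repository): sparsemax
# returns exact zeros for low-scoring entries, while softmax keeps every entry
# strictly positive:
#
#     scores = torch.tensor([[2.0, 1.0, -1.0]])
#     sparsemax(scores)           # tensor([[1., 0., 0.]])
#     F.softmax(scores, dim=-1)   # tensor([[0.7054, 0.2595, 0.0351]])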
class ParaformerAttention(nn.Module):
"""
Attention mechanism for Paraformer model.
This implements a general attention mechanism with optional sparsemax activation.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.use_sparsemax = config.use_sparsemax
# Attention layers
if config.attention_type == "general":
self.attention_weights = nn.Linear(config.hidden_size, 1, bias=False)
else:
raise ValueError(f"Unsupported attention type: {config.attention_type}")
def forward(self, query_embedding, sentence_embeddings, attention_mask=None):
"""
Apply attention mechanism.
Args:
query_embedding: Query embedding tensor [batch_size, hidden_size]
sentence_embeddings: Sentence embeddings [batch_size, num_sentences, hidden_size]
attention_mask: Mask for padding sentences [batch_size, num_sentences]
Returns:
attended_output: Weighted combination of sentence embeddings
attention_weights: Attention weights for interpretability
"""
batch_size, num_sentences, hidden_size = sentence_embeddings.shape
# Expand query embedding to match sentence embeddings
query_expanded = query_embedding.unsqueeze(1).expand(-1, num_sentences, -1)
# Compute attention scores using general attention
# Combine query and sentence embeddings
combined = query_expanded * sentence_embeddings # Element-wise multiplication
attention_scores = self.attention_weights(combined).squeeze(-1) # [batch_size, num_sentences]
        # Apply attention mask if provided; use the dtype minimum instead of -inf so
        # that sparsemax (which zeroes out masked scores) cannot produce NaNs
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(
                ~attention_mask.bool(), torch.finfo(attention_scores.dtype).min
            )
# Apply sparsemax or softmax
if self.use_sparsemax:
attention_weights = sparsemax(attention_scores, dim=-1)
else:
attention_weights = F.softmax(attention_scores, dim=-1)
        # Weighted sum of sentence embeddings; clone() guards against inference-mode
        # tensors coming straight out of SentenceTransformer.encode
        attended_output = torch.sum(attention_weights.unsqueeze(-1) * sentence_embeddings.clone(), dim=1)
return attended_output, attention_weights
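# Shape sketch (illustrative sizes, assuming config.hidden_size == 384):
#
#     attn = ParaformerAttention(config)
#     query = torch.randn(2, 384)        # [batch_size, hidden_size]
#     sents = torch.randn(2, 5, 384)     # [batch_size, num_sentences, hidden_size]
#     out, weights = attn(query, sents)  # out: [2, 384], weights: [2, 5]
#
# Each row of `weights` sums to 1; with use_sparsemax=True some entries can be
# exactly zero, which is what makes the per-sentence attention interpretable.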
class ParaformerModel(PreTrainedModel):
"""
Paraformer model for legal document retrieval.
This model uses a hierarchical approach with attention mechanism to encode legal documents
and queries for relevance classification.
"""
config_class = ParaformerConfig
base_model_prefix = "paraformer"
supports_gradient_checkpointing = True
_no_split_modules = ["ParaformerAttention"]
def __init__(self, config):
super().__init__(config)
self.config = config
# Don't initialize SentenceTransformer in __init__ to avoid meta tensor issues
self._sentence_encoder = None
# Attention mechanism
self.attention = ParaformerAttention(config)
# Classifier
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.dropout = nn.Dropout(config.dropout_prob)
# Initialize weights
self.post_init()
@property
def sentence_encoder(self):
"""Lazy loading of SentenceTransformer to avoid meta tensor issues"""
if self._sentence_encoder is None:
from sentence_transformers import SentenceTransformer
self._sentence_encoder = SentenceTransformer(self.config.base_model_name)
return self._sentence_encoder
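    # The encoder is therefore built on first access, e.g. during the first forward
    # pass, rather than inside __init__:
    #
    #     model = ParaformerModel(config)  # no SentenceTransformer loaded yet
    #     _ = model.sentence_encoder       # base encoder is instantiated here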
def forward(
self,
query_texts: Optional[List[str]] = None,
article_texts: Optional[List[List[str]]] = None,
labels: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
**kwargs
):
"""
Forward pass of the Paraformer model.
Args:
query_texts: List of query strings
article_texts: List of article sentence lists
labels: Optional labels for training
return_dict: Whether to return a dictionary
Returns:
Model outputs including logits and optional loss
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if query_texts is None or article_texts is None:
raise ValueError("Both query_texts and article_texts must be provided")
batch_size = len(query_texts)
device = next(self.parameters()).device
# Encode queries
query_embeddings = self.sentence_encoder.encode(
query_texts,
convert_to_tensor=True,
device=device
).clone() # Clone to avoid inference tensor issues
# Process articles
all_attended_outputs = []
all_attention_weights = []
for i, article in enumerate(article_texts):
if not article: # Handle empty articles
attended_output = torch.zeros(self.config.hidden_size, device=device)
attention_weights = torch.zeros(1, device=device)
else:
# Encode article sentences
sentence_embeddings = self.sentence_encoder.encode(
article,
convert_to_tensor=True,
device=device
).clone() # Clone to avoid inference tensor issues
# Add batch dimension if needed
if sentence_embeddings.dim() == 2:
sentence_embeddings = sentence_embeddings.unsqueeze(0)
# Apply attention
attended_output, attention_weights = self.attention(
query_embeddings[i:i+1],
sentence_embeddings
)
attended_output = attended_output.squeeze(0)
attention_weights = attention_weights.squeeze(0)
all_attended_outputs.append(attended_output)
all_attention_weights.append(attention_weights)
# Stack outputs
attended_outputs = torch.stack(all_attended_outputs)
# Apply dropout and classifier
attended_outputs = self.dropout(attended_outputs)
logits = self.classifier(attended_outputs)
# Compute loss if labels provided
loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Move labels to the logits' device in case they were created on CPU
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1).to(logits.device))
if not return_dict:
output = (logits,) + (all_attention_weights,)
return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            # Articles may contain different numbers of sentences, so the per-article
            # attention weights are returned as a tuple rather than stacked (stacking
            # tensors of unequal length would raise a runtime error)
            attentions=tuple(w.unsqueeze(0) for w in all_attention_weights) if all_attention_weights else None,
        )
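    # Illustrative batched call (hypothetical texts; labels are class indices as
    # expected by CrossEntropyLoss):
    #
    #     outputs = model(
    #         query_texts=["query one", "query two"],
    #         article_texts=[["sentence a", "sentence b"], ["sentence c"]],
    #         labels=torch.tensor([1, 0]),
    #         return_dict=True,
    #     )
    #     outputs.loss, outputs.logits  # logits: [2, num_labels]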
def get_relevance_score(self, query: str, article: List[str]) -> float:
"""
Get relevance score for a single query-article pair.
Args:
query: Query string
article: List of article sentences
Returns:
Relevance score between 0 and 1
"""
self.eval()
with torch.no_grad():
outputs = self.forward(
query_texts=[query],
article_texts=[article],
return_dict=True
)
probabilities = torch.softmax(outputs.logits, dim=-1)
relevance_score = probabilities[0, 1].item() # Probability of being relevant
return relevance_score
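    # Example with hypothetical texts (the returned float is the softmax
    # probability of the "relevant" class, so num_labels is assumed to be 2):
    #
    #     score = model.get_relevance_score(
    #         "Is an oral agreement enforceable?",
    #         ["Article 1. Contracts may be concluded orally or in writing.",
    #          "Article 2. Certain contracts must be made in writing."],
    #     )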
def predict_relevance(self, query: str, article: List[str]) -> int:
"""
Predict binary relevance for a single query-article pair.
Args:
query: Query string
article: List of article sentences
Returns:
Binary prediction (0 = not relevant, 1 = relevant)
"""
self.eval()
with torch.no_grad():
outputs = self.forward(
query_texts=[query],
article_texts=[article],
return_dict=True
)
prediction = torch.argmax(outputs.logits, dim=-1).item()
return prediction
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
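

if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming sentence-transformers is installed and
    # that config.hidden_size matches the base encoder's embedding size.
    # "sentence-transformers/all-MiniLM-L6-v2" (384-dim embeddings) is an
    # illustrative choice, not necessarily the encoder shipped with this checkpoint.
    config = ParaformerConfig(
        base_model_name="sentence-transformers/all-MiniLM-L6-v2",
        hidden_size=384,
    )
    model = ParaformerModel(config)
    query = "What is the notice period for terminating a lease?"
    article = [
        "Article 12. Either party may terminate the lease.",
        "Notice of termination must be given thirty days in advance.",
    ]
    print("relevance score:", model.get_relevance_score(query, article))
    print("prediction:", model.predict_relevance(query, article))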