"""
Paraformer model implementation for Hugging Face Transformers.

This module implements the Paraformer model for legal document retrieval,
based on the paper "Attentive Deep Neural Networks for Legal Document Retrieval".
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

try:
    from .configuration_paraformer import ParaformerConfig
except ImportError:
    from configuration_paraformer import ParaformerConfig

logger = logging.get_logger(__name__)


def sparsemax(input_tensor, dim=-1):
    """
    Sparsemax activation function (Martins & Astudillo, 2016).

    Unlike softmax, sparsemax projects the scores onto the probability
    simplex and can assign exactly zero weight to low-scoring entries.

    Args:
        input_tensor: Input tensor.
        dim: Dimension along which to apply sparsemax.

    Returns:
        Sparsemax output tensor with the same shape as the input.
    """
    # Sort scores in descending order along the target dimension.
    sorted_input, _ = torch.sort(input_tensor, dim=dim, descending=True)

    # Shifted cumulative sum, used to determine the support set.
    input_cumsum = torch.cumsum(sorted_input, dim=dim) - 1

    # Position indices 1..K, broadcast along the target dimension.
    k = torch.arange(1, input_tensor.size(dim) + 1, dtype=input_tensor.dtype, device=input_tensor.device)
    shape = [1] * input_tensor.dim()
    shape[dim] = -1
    k = k.view(shape)

    # Support set: positions where k * z_(k) > cumsum(z)_k - 1.
    support = (k * sorted_input > input_cumsum).to(input_tensor.dtype)
    support_size = torch.sum(support, dim=dim, keepdim=True)

    # Threshold tau = (sum over the support - 1) / |support|.
    tau_sum = torch.sum(sorted_input * support, dim=dim, keepdim=True)
    tau = (tau_sum - 1) / support_size

    # Euclidean projection onto the simplex.
    return torch.clamp(input_tensor - tau, min=0)
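

# Quick sanity check (illustrative): sparsemax sums to 1 along the chosen
# dimension and, unlike softmax, can zero out low scores entirely:
#
#     >>> sparsemax(torch.tensor([[2.0, 1.0, -3.0]]))
#     tensor([[1., 0., 0.]])
#
# The gap between 2.0 and the remaining scores is large enough that all of
# the probability mass collapses onto the first entry.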


class ParaformerAttention(nn.Module):
    """
    Attention mechanism for the Paraformer model.

    Implements a general attention mechanism over sentence embeddings, with an
    optional sparsemax activation that produces sparse, interpretable weights.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.use_sparsemax = config.use_sparsemax

        if config.attention_type == "general":
            # Projects each query-sentence interaction vector to a scalar score.
            self.attention_weights = nn.Linear(config.hidden_size, 1, bias=False)
        else:
            raise ValueError(f"Unsupported attention type: {config.attention_type}")

    def forward(self, query_embedding, sentence_embeddings, attention_mask=None):
        """
        Apply the attention mechanism.

        Args:
            query_embedding: Query embedding tensor [batch_size, hidden_size].
            sentence_embeddings: Sentence embeddings [batch_size, num_sentences, hidden_size].
            attention_mask: Mask for padding sentences [batch_size, num_sentences];
                nonzero/True entries mark valid sentences.

        Returns:
            attended_output: Weighted combination of sentence embeddings.
            attention_weights: Attention weights, useful for interpretability.
        """
        batch_size, num_sentences, hidden_size = sentence_embeddings.shape

        # Broadcast the query against every sentence in the article.
        query_expanded = query_embedding.unsqueeze(1).expand(-1, num_sentences, -1)

        # Element-wise interaction followed by a scalar projection.
        combined = query_expanded * sentence_embeddings
        attention_scores = self.attention_weights(combined).squeeze(-1)

        # Exclude padding sentences from the normalization.
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(~attention_mask.bool(), float("-inf"))

        # Sparsemax zeroes out irrelevant sentences; softmax keeps all active.
        if self.use_sparsemax:
            attention_weights = sparsemax(attention_scores, dim=-1)
        else:
            attention_weights = F.softmax(attention_scores, dim=-1)

        # Weighted sum of sentence embeddings.
        attended_output = torch.sum(attention_weights.unsqueeze(-1) * sentence_embeddings, dim=1)

        return attended_output, attention_weights
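

# A minimal shape check (illustrative; builds a stand-in config with only the
# attributes ParaformerAttention reads):
#
#     >>> from types import SimpleNamespace
#     >>> cfg = SimpleNamespace(hidden_size=8, use_sparsemax=True, attention_type="general")
#     >>> attn = ParaformerAttention(cfg)
#     >>> out, weights = attn(torch.randn(1, 8), torch.randn(1, 5, 8))
#     >>> out.shape, weights.shape
#     (torch.Size([1, 8]), torch.Size([1, 5]))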


class ParaformerModel(PreTrainedModel):
    """
    Paraformer model for legal document retrieval.

    Uses a hierarchical approach: sentences are embedded with a
    SentenceTransformer, combined per article through the attention mechanism,
    and the attended representation is classified for query relevance.
    """

    config_class = ParaformerConfig
    base_model_prefix = "paraformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ParaformerAttention"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # The SentenceTransformer is loaded lazily (see the property below)
        # to avoid meta tensor issues during checkpoint loading.
        self._sentence_encoder = None

        self.attention = ParaformerAttention(config)

        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.dropout = nn.Dropout(config.dropout_prob)

        # Initialize weights and apply any final processing.
        self.post_init()

    @property
    def sentence_encoder(self):
        """Lazily load the SentenceTransformer to avoid meta tensor issues."""
        if self._sentence_encoder is None:
            from sentence_transformers import SentenceTransformer
            self._sentence_encoder = SentenceTransformer(self.config.base_model_name)
        return self._sentence_encoder
    def forward(
        self,
        query_texts: Optional[List[str]] = None,
        article_texts: Optional[List[List[str]]] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """
        Forward pass of the Paraformer model.

        Args:
            query_texts: List of query strings.
            article_texts: List of articles, each given as a list of sentences.
            labels: Optional labels for training.
            return_dict: Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            Model outputs including logits and, when labels are given, the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if query_texts is None or article_texts is None:
            raise ValueError("Both query_texts and article_texts must be provided")

        device = next(self.parameters()).device

        # Encode all queries in one batch. encode() runs without gradients, so
        # clone() gives us a detached tensor the model owns.
        query_embeddings = self.sentence_encoder.encode(
            query_texts,
            convert_to_tensor=True,
            device=device
        ).clone()

        all_attended_outputs = []
        all_attention_weights = []

        # Articles contain varying numbers of sentences, so attend per article.
        for i, article in enumerate(article_texts):
            if not article:
                # Empty article: fall back to a zero representation.
                attended_output = torch.zeros(self.config.hidden_size, device=device)
                attention_weights = torch.zeros(1, device=device)
            else:
                sentence_embeddings = self.sentence_encoder.encode(
                    article,
                    convert_to_tensor=True,
                    device=device
                ).clone()

                # Add a batch dimension: [1, num_sentences, hidden_size].
                if sentence_embeddings.dim() == 2:
                    sentence_embeddings = sentence_embeddings.unsqueeze(0)

                attended_output, attention_weights = self.attention(
                    query_embeddings[i:i+1],
                    sentence_embeddings
                )
                attended_output = attended_output.squeeze(0)
                attention_weights = attention_weights.squeeze(0)

            all_attended_outputs.append(attended_output)
            all_attention_weights.append(attention_weights)

        attended_outputs = torch.stack(all_attended_outputs)

        attended_outputs = self.dropout(attended_outputs)
        logits = self.classifier(attended_outputs)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits, all_attention_weights)
            return ((loss,) + output) if loss is not None else output

        # Per-article attention weights have different lengths, so they are
        # returned as a tuple rather than stacked into a single tensor.
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=tuple(all_attention_weights) if all_attention_weights else None,
        )
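
    # Example forward call (illustrative; output shape assumes num_labels == 2):
    #
    #     >>> outputs = model(
    #     ...     query_texts=["What is the notice period?"],
    #     ...     article_texts=[["Sentence one.", "Sentence two."]],
    #     ...     return_dict=True,
    #     ... )
    #     >>> outputs.logits.shape
    #     torch.Size([1, 2])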

    def get_relevance_score(self, query: str, article: List[str]) -> float:
        """
        Get the relevance score for a single query-article pair.

        Args:
            query: Query string.
            article: List of article sentences.

        Returns:
            Relevance score between 0 and 1 (probability of the "relevant" class).
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                query_texts=[query],
                article_texts=[article],
                return_dict=True
            )

        # Probability of the positive class; assumes binary classification.
        probabilities = torch.softmax(outputs.logits, dim=-1)
        relevance_score = probabilities[0, 1].item()

        return relevance_score

    def predict_relevance(self, query: str, article: List[str]) -> int:
        """
        Predict binary relevance for a single query-article pair.

        Args:
            query: Query string.
            article: List of article sentences.

        Returns:
            Binary prediction (0 = not relevant, 1 = relevant).
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                query_texts=[query],
                article_texts=[article],
                return_dict=True
            )

        prediction = torch.argmax(outputs.logits, dim=-1).item()

        return prediction

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
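

# A minimal usage sketch (illustrative; "path/to/paraformer" is a hypothetical
# local checkpoint directory, and sentence-transformers must be installed):
#
#     >>> model = ParaformerModel.from_pretrained("path/to/paraformer")
#     >>> article = ["The tenant must give 30 days notice.", "Notice must be in writing."]
#     >>> score = model.get_relevance_score("How much notice must a tenant give?", article)
#     >>> 0.0 <= score <= 1.0
#     True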