Auto Classes
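In many cases, the architecture you want to use can be inferred from the name or path of the pretrained model you pass to the from_pretrained() method. AutoClasses exist to do this for you: given the name or path of the pretrained weights, configuration, or vocabulary, they automatically retrieve the relevant model.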
Instantiating one of AutoConfig, AutoModel, or AutoTokenizer directly creates a class of the relevant architecture. For example,
from transformers import AutoModel
model = AutoModel.from_pretrained("google-bert/bert-base-cased")
creates a model that is an instance of BertModel.
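Conceptually, the auto class reads the model_type recorded in the checkpoint's config.json and looks it up in a mapping. The sketch models that lookup in pure Python so it runs without Transformers installed. It is a simplified illustration, not the library's implementation. MODEL_CLASS_NAMES and resolve_model_class are hypothetical names, the table is a tiny excerpt, and class names are plain strings rather than the real classes.

```python
import json
import os
import tempfile

# Hypothetical, heavily trimmed stand-in for a model_type -> model class mapping.
MODEL_CLASS_NAMES = {
    "bert": "BertModel",
    "gpt2": "GPT2Model",
    "t5": "T5Model",
}


def resolve_model_class(checkpoint_dir: str) -> str:
    """Read config.json in checkpoint_dir and return the class name for its model_type."""
    with open(os.path.join(checkpoint_dir, "config.json"), encoding="utf-8") as f:
        model_type = json.load(f)["model_type"]
    try:
        return MODEL_CLASS_NAMES[model_type]
    except KeyError:
        raise ValueError(f"Unrecognized model_type: {model_type!r}") from None


# Usage: simulate a local checkpoint directory whose config declares a BERT model.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"model_type": "bert", "hidden_size": 768}, f)
    print(resolve_model_class(tmp))  # BertModel
```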
There is one AutoModel class for each task, and for each backend (PyTorch, TensorFlow, or Flax).
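To illustrate the one-class-per-task-and-backend layout, here is a toy table in which the same model_type resolves to a different class depending on which auto class is used. The class names (BertModel, BertForSequenceClassification, TFBertModel, FlaxBertModel, and their GPT-2 counterparts) are real Transformers classes. AUTO_CLASS_TABLES and lookup are hypothetical and only sketch the idea.

```python
# Hypothetical sketch: each auto class owns its own model_type -> class-name table.
# The tables are trimmed to two architectures for illustration.
AUTO_CLASS_TABLES = {
    "AutoModel": {"bert": "BertModel", "gpt2": "GPT2Model"},
    "AutoModelForSequenceClassification": {
        "bert": "BertForSequenceClassification",
        "gpt2": "GPT2ForSequenceClassification",
    },
    "TFAutoModel": {"bert": "TFBertModel", "gpt2": "TFGPT2Model"},
    "FlaxAutoModel": {"bert": "FlaxBertModel", "gpt2": "FlaxGPT2Model"},
}


def lookup(auto_class: str, model_type: str) -> str:
    """Return the concrete class name an auto class would pick for model_type."""
    return AUTO_CLASS_TABLES[auto_class][model_type]


# The same "bert" checkpoint maps to a different class per task and backend.
for auto_class in AUTO_CLASS_TABLES:
    print(f"{auto_class}: {lookup(auto_class, 'bert')}")
```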
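Extending the Auto Classes
Each auto class has a method that lets you extend it with your own custom classes. For example, if you have defined a custom model class NewModel, make sure you also have a NewModelConfig, and then you can add them to the auto classes like this: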
from transformers import AutoConfig, AutoModel
# NewModelConfig and NewModel are your own classes, defined beforehand.
AutoConfig.register("new-model", NewModelConfig)
AutoModel.register(NewModelConfig, NewModel)
You can then use the auto classes just as you normally would!
If your NewModelConfig is a subclass of PretrainedConfig, make sure its model_type attribute is set to the same key you use when registering the config (here "new-model").
Likewise, if your NewModel is a subclass of PreTrainedModel, make sure its config_class attribute is set to the same class you use when registering the model (here NewModelConfig).
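The two consistency rules above can be sketched as plain checks performed at registration time. This is a hypothetical, self-contained model. ToyPretrainedConfig and ToyPreTrainedModel stand in for PretrainedConfig and PreTrainedModel, and the registries are ordinary dicts rather than the library's internal mappings. The sketch assumes that a mismatch should be rejected, and signals this by raising ValueError.

```python
class ToyPretrainedConfig:
    """Stand-in for PretrainedConfig (hypothetical)."""
    model_type = ""


class ToyPreTrainedModel:
    """Stand-in for PreTrainedModel (hypothetical)."""
    config_class = None


CONFIG_REGISTRY: dict[str, type] = {}
MODEL_REGISTRY: dict[type, type] = {}


def register_config(model_type: str, config: type) -> None:
    # The registration key must match the config's own model_type attribute.
    if issubclass(config, ToyPretrainedConfig) and config.model_type != model_type:
        raise ValueError(
            f"{config.__name__}.model_type is {config.model_type!r}, "
            f"but it was registered as {model_type!r}"
        )
    CONFIG_REGISTRY[model_type] = config


def register_model(config_class: type, model_class: type) -> None:
    # The model's config_class must be the config class it is registered under.
    if issubclass(model_class, ToyPreTrainedModel) and model_class.config_class is not config_class:
        raise ValueError(f"{model_class.__name__}.config_class does not match {config_class.__name__}")
    MODEL_REGISTRY[config_class] = model_class


class NewModelConfig(ToyPretrainedConfig):
    model_type = "new-model"


class NewModel(ToyPreTrainedModel):
    config_class = NewModelConfig


register_config("new-model", NewModelConfig)
register_model(NewModelConfig, NewModel)

# Resolving from the string key now reaches the custom model class.
print(MODEL_REGISTRY[CONFIG_REGISTRY["new-model"]].__name__)  # NewModel
```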
AutoConfig
This is a generic configuration class that will be instantiated as one of the configuration classes of the library when created with the from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
( pretrained_model_name_or_path, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike): Can be either:
  - A string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.
  - A path to a directory containing a configuration file saved using the PretrainedConfig.save_pretrained() method or the PreTrainedModel.save_pretrained() method, e.g., ./my_model_directory/.
  - A path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.
- cache_dir (str or os.PathLike, optional): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download: Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional): A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- revision (str, optional, defaults to "main"): The specific model version to use. It can be a branch name, a tag name, or a commit id. Because huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- return_unused_kwargs (bool, optional, defaults to False): If False, this function returns just the final configuration object. If True, this function returns a Tuple(config, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not configuration attributes: that is, the part of kwargs that was not used to update config and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False): Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code from the Hub on your local machine.
- kwargs (additional keyword arguments, optional): The values in kwargs for any keys that are configuration attributes will be used to override the loaded values. Key/value pairs whose keys are not configuration attributes are handled according to the return_unused_kwargs keyword parameter.
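The effect of kwargs and return_unused_kwargs can be sketched as follows. The sketch assumes that a key counts as a configuration attribute when the loaded config object already has an attribute with that name. ToyConfig and apply_overrides are hypothetical stand-ins, not library code. The call mirrors the output_attentions / foo case in the examples for this method.

```python
class ToyConfig:
    """Hypothetical stand-in for a loaded configuration object."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)


def apply_overrides(config, kwargs, return_unused_kwargs=False):
    """Override existing config attributes from kwargs and collect the rest as unused."""
    unused = {}
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)  # known attribute: override the loaded value
        else:
            unused[key] = value  # not a config attribute: ignored unless returned
    return (config, unused) if return_unused_kwargs else config


loaded = ToyConfig(output_attentions=False, hidden_size=768)
config, unused = apply_overrides(
    loaded, {"output_attentions": True, "foo": False}, return_unused_kwargs=True
)
print(config.output_attentions, unused)  # True {'foo': False}
```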
Instantiate one of the configuration classes of the library from a pretrained model configuration.
The configuration class to instantiate is selected based on the model_type property of the config object that is loaded, or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert: AlbertConfig (ALBERT model)
- align: AlignConfig (ALIGN model)
- altclip: AltCLIPConfig (AltCLIP model)
- aria: AriaConfig (Aria model)
- aria_text: AriaTextConfig (AriaText model)
- audio-spectrogram-transformer: ASTConfig (Audio Spectrogram Transformer model)
- autoformer: AutoformerConfig (Autoformer model)
- bamba: BambaConfig (Bamba model)
- bark: BarkConfig (Bark model)
- bart: BartConfig (BART model)
- beit: BeitConfig (BEiT model)
- bert: BertConfig (BERT model)
- bert-generation: BertGenerationConfig (Bert Generation model)
- big_bird: BigBirdConfig (BigBird model)
- bigbird_pegasus: BigBirdPegasusConfig (BigBird-Pegasus model)
- biogpt: BioGptConfig (BioGpt model)
- bit: BitConfig (BiT model)
- blenderbot: BlenderbotConfig (Blenderbot model)
- blenderbot-small: BlenderbotSmallConfig (BlenderbotSmall model)
- blip: BlipConfig (BLIP model)
- blip-2: Blip2Config (BLIP-2 model)
- bloom: BloomConfig (BLOOM model)
- bridgetower: BridgeTowerConfig (BridgeTower model)
- bros: BrosConfig (BROS model)
- camembert: CamembertConfig (CamemBERT model)
- canine: CanineConfig (CANINE model)
- chameleon: ChameleonConfig (Chameleon model)
- chinese_clip: ChineseCLIPConfig (Chinese-CLIP model)
- chinese_clip_vision_model: ChineseCLIPVisionConfig (ChineseCLIPVisionModel model)
- clap: ClapConfig (CLAP model)
- clip: CLIPConfig (CLIP model)
- clip_text_model: CLIPTextConfig (CLIPTextModel model)
- clip_vision_model: CLIPVisionConfig (CLIPVisionModel model)
- clipseg: CLIPSegConfig (CLIPSeg model)
- clvp: ClvpConfig (CLVP model)
- code_llama: LlamaConfig (CodeLlama model)
- codegen: CodeGenConfig (CodeGen model)
- cohere: CohereConfig (Cohere model)
- cohere2: Cohere2Config (Cohere2 model)
- colpali: ColPaliConfig (ColPali model)
- conditional_detr: ConditionalDetrConfig (Conditional DETR model)
- convbert: ConvBertConfig (ConvBERT model)
- convnext: ConvNextConfig (ConvNeXT model)
- convnextv2: ConvNextV2Config (ConvNeXTV2 model)
- cpmant: CpmAntConfig (CPM-Ant model)
- ctrl: CTRLConfig (CTRL model)
- cvt: CvtConfig (CvT model)
- dac: DacConfig (DAC model)
- data2vec-audio: Data2VecAudioConfig (Data2VecAudio model)
- data2vec-text: Data2VecTextConfig (Data2VecText model)
- data2vec-vision: Data2VecVisionConfig (Data2VecVision model)
- dbrx: DbrxConfig (DBRX model)
- deberta: DebertaConfig (DeBERTa model)
- deberta-v2: DebertaV2Config (DeBERTa-v2 model)
- decision_transformer: DecisionTransformerConfig (Decision Transformer model)
- deformable_detr: DeformableDetrConfig (Deformable DETR model)
- deit: DeiTConfig (DeiT model)
- depth_anything: DepthAnythingConfig (Depth Anything model)
- deta: DetaConfig (DETA model)
- detr: DetrConfig (DETR model)
- diffllama: DiffLlamaConfig (DiffLlama model)
- dinat: DinatConfig (DiNAT model)
- dinov2: Dinov2Config (DINOv2 model)
- dinov2_with_registers: Dinov2WithRegistersConfig (DINOv2 with Registers model)
- distilbert: DistilBertConfig (DistilBERT model)
- donut-swin: DonutSwinConfig (DonutSwin model)
- dpr: DPRConfig (DPR model)
- dpt: DPTConfig (DPT model)
- efficientformer: EfficientFormerConfig (EfficientFormer model)
- efficientnet: EfficientNetConfig (EfficientNet model)
- electra: ElectraConfig (ELECTRA model)
- emu3: Emu3Config (Emu3 model)
- encodec: EncodecConfig (EnCodec model)
- encoder-decoder: EncoderDecoderConfig (Encoder decoder model)
- ernie: ErnieConfig (ERNIE model)
- ernie_m: ErnieMConfig (ErnieM model)
- esm: EsmConfig (ESM model)
- falcon: FalconConfig (Falcon model)
- falcon_mamba: FalconMambaConfig (FalconMamba model)
- fastspeech2_conformer: FastSpeech2ConformerConfig (FastSpeech2Conformer model)
- flaubert: FlaubertConfig (FlauBERT model)
- flava: FlavaConfig (FLAVA model)
- fnet: FNetConfig (FNet model)
- focalnet: FocalNetConfig (FocalNet model)
- fsmt: FSMTConfig (FairSeq Machine-Translation model)
- funnel: FunnelConfig (Funnel Transformer model)
- fuyu: FuyuConfig (Fuyu model)
- gemma: GemmaConfig (Gemma model)
- gemma2: Gemma2Config (Gemma2 model)
- git: GitConfig (GIT model)
- glm: GlmConfig (GLM model)
- glpn: GLPNConfig (GLPN model)
- gpt-sw3: GPT2Config (GPT-Sw3 model)
- gpt2: GPT2Config (OpenAI GPT-2 model)
- gpt_bigcode: GPTBigCodeConfig (GPTBigCode model)
- gpt_neo: GPTNeoConfig (GPT Neo model)
- gpt_neox: GPTNeoXConfig (GPT NeoX model)
- gpt_neox_japanese: GPTNeoXJapaneseConfig (GPT NeoX Japanese model)
- gptj: GPTJConfig (GPT-J model)
- gptsan-japanese: GPTSanJapaneseConfig (GPTSAN-japanese model)
- granite: GraniteConfig (Granite model)
- granitemoe: GraniteMoeConfig (GraniteMoeMoe model)
- graphormer: GraphormerConfig (Graphormer model)
- grounding-dino: GroundingDinoConfig (Grounding DINO model)
- groupvit: GroupViTConfig (GroupViT model)
- helium: HeliumConfig (Helium model)
- hiera: HieraConfig (Hiera model)
- hubert: HubertConfig (Hubert model)
- ibert: IBertConfig (I-BERT model)
- idefics: IdeficsConfig (IDEFICS model)
- idefics2: Idefics2Config (Idefics2 model)
- idefics3: Idefics3Config (Idefics3 model)
- idefics3_vision: Idefics3VisionConfig (Idefics3VisionTransformer model)
- ijepa: IJepaConfig (I-JEPA model)
- imagegpt: ImageGPTConfig (ImageGPT model)
- informer: InformerConfig (Informer model)
- instructblip: InstructBlipConfig (InstructBLIP model)
- instructblipvideo: InstructBlipVideoConfig (InstructBlipVideo model)
- jamba: JambaConfig (Jamba model)
- jetmoe: JetMoeConfig (JetMoe model)
- jukebox: JukeboxConfig (Jukebox model)
- kosmos-2: Kosmos2Config (KOSMOS-2 model)
- layoutlm: LayoutLMConfig (LayoutLM model)
- layoutlmv2: LayoutLMv2Config (LayoutLMv2 model)
- layoutlmv3: LayoutLMv3Config (LayoutLMv3 model)
- led: LEDConfig (LED model)
- levit: LevitConfig (LeViT model)
- lilt: LiltConfig (LiLT model)
- llama: LlamaConfig (LLaMA model)
- llava: LlavaConfig (LLaVa model)
- llava_next: LlavaNextConfig (LLaVA-NeXT model)
- llava_next_video: LlavaNextVideoConfig (LLaVa-NeXT-Video model)
- llava_onevision: LlavaOnevisionConfig (LLaVA-Onevision model)
- longformer: LongformerConfig (Longformer model)
- longt5: LongT5Config (LongT5 model)
- luke: LukeConfig (LUKE model)
- lxmert: LxmertConfig (LXMERT model)
- m2m_100: M2M100Config (M2M100 model)
- mamba: MambaConfig (Mamba model)
- mamba2: Mamba2Config (mamba2 model)
- marian: MarianConfig (Marian model)
- markuplm: MarkupLMConfig (MarkupLM model)
- mask2former: Mask2FormerConfig (Mask2Former model)
- maskformer: MaskFormerConfig (MaskFormer model)
- maskformer-swin: MaskFormerSwinConfig (MaskFormerSwin model)
- mbart: MBartConfig (mBART model)
- mctct: MCTCTConfig (M-CTC-T model)
- mega: MegaConfig (MEGA model)
- megatron-bert: MegatronBertConfig (Megatron-BERT model)
- mgp-str: MgpstrConfig (MGP-STR model)
- mimi: MimiConfig (Mimi model)
- mistral: MistralConfig (Mistral model)
- mixtral: MixtralConfig (Mixtral model)
- mllama: MllamaConfig (Mllama model)
- mobilebert: MobileBertConfig (MobileBERT model)
- mobilenet_v1: MobileNetV1Config (MobileNetV1 model)
- mobilenet_v2: MobileNetV2Config (MobileNetV2 model)
- mobilevit: MobileViTConfig (MobileViT model)
- mobilevitv2: MobileViTV2Config (MobileViTV2 model)
- modernbert: ModernBertConfig (ModernBERT model)
- moonshine: MoonshineConfig (Moonshine model)
- moshi: MoshiConfig (Moshi model)
- mpnet: MPNetConfig (MPNet model)
- mpt: MptConfig (MPT model)
- mra: MraConfig (MRA model)
- mt5: MT5Config (MT5 model)
- musicgen: MusicgenConfig (MusicGen model)
- musicgen_melody: MusicgenMelodyConfig (MusicGen Melody model)
- mvp: MvpConfig (MVP model)
- nat: NatConfig (NAT model)
- nemotron: NemotronConfig (Nemotron model)
- nezha: NezhaConfig (Nezha model)
- nllb-moe: NllbMoeConfig (NLLB-MOE model)
- nougat: VisionEncoderDecoderConfig (Nougat model)
- nystromformer: NystromformerConfig (Nyströmformer model)
- olmo: OlmoConfig (OLMo model)
- olmo2: Olmo2Config (OLMo2 model)
- olmoe: OlmoeConfig (OLMoE model)
- omdet-turbo: OmDetTurboConfig (OmDet-Turbo model)
- oneformer: OneFormerConfig (OneFormer model)
- open-llama: OpenLlamaConfig (OpenLlama model)
- openai-gpt: OpenAIGPTConfig (OpenAI GPT model)
- opt: OPTConfig (OPT model)
- owlv2: Owlv2Config (OWLv2 model)
- owlvit: OwlViTConfig (OWL-ViT model)
- paligemma: PaliGemmaConfig (PaliGemma model)
- patchtsmixer: PatchTSMixerConfig (PatchTSMixer model)
- patchtst: PatchTSTConfig (PatchTST model)
- pegasus: PegasusConfig (Pegasus model)
- pegasus_x: PegasusXConfig (PEGASUS-X model)
- perceiver: PerceiverConfig (Perceiver model)
- persimmon: PersimmonConfig (Persimmon model)
- phi: PhiConfig (Phi model)
- phi3: Phi3Config (Phi3 model)
- phimoe: PhimoeConfig (Phimoe model)
- pix2struct: Pix2StructConfig (Pix2Struct model)
- pixtral: PixtralVisionConfig (Pixtral model)
- plbart: PLBartConfig (PLBart model)
- poolformer: PoolFormerConfig (PoolFormer model)
- pop2piano: Pop2PianoConfig (Pop2Piano model)
- prophetnet: ProphetNetConfig (ProphetNet model)
- pvt: PvtConfig (PVT model)
- pvt_v2: PvtV2Config (PVTv2 model)
- qdqbert: QDQBertConfig (QDQBert model)
- qwen2: Qwen2Config (Qwen2 model)
- qwen2_audio: Qwen2AudioConfig (Qwen2Audio model)
- qwen2_audio_encoder: Qwen2AudioEncoderConfig (Qwen2AudioEncoder model)
- qwen2_moe: Qwen2MoeConfig (Qwen2MoE model)
- qwen2_vl: Qwen2VLConfig (Qwen2VL model)
- rag: RagConfig (RAG model)
- realm: RealmConfig (REALM model)
- recurrent_gemma: RecurrentGemmaConfig (RecurrentGemma model)
- reformer: ReformerConfig (Reformer model)
- regnet: RegNetConfig (RegNet model)
- rembert: RemBertConfig (RemBERT model)
- resnet: ResNetConfig (ResNet model)
- retribert: RetriBertConfig (RetriBERT model)
- roberta: RobertaConfig (RoBERTa model)
- roberta-prelayernorm: RobertaPreLayerNormConfig (RoBERTa-PreLayerNorm model)
- roc_bert: RoCBertConfig (RoCBert model)
- roformer: RoFormerConfig (RoFormer model)
- rt_detr: RTDetrConfig (RT-DETR model)
- rt_detr_resnet: RTDetrResNetConfig (RT-DETR-ResNet model)
- rwkv: RwkvConfig (RWKV model)
- sam: SamConfig (SAM model)
- seamless_m4t: SeamlessM4TConfig (SeamlessM4T model)
- seamless_m4t_v2: SeamlessM4Tv2Config (SeamlessM4Tv2 model)
- segformer: SegformerConfig (SegFormer model)
- seggpt: SegGptConfig (SegGPT model)
- sew: SEWConfig (SEW model)
- sew-d: SEWDConfig (SEW-D model)
- siglip: SiglipConfig (SigLIP model)
- siglip_vision_model: SiglipVisionConfig (SiglipVisionModel model)
- speech-encoder-decoder: SpeechEncoderDecoderConfig (Speech Encoder decoder model)
- speech_to_text: Speech2TextConfig (Speech2Text model)
- speech_to_text_2: Speech2Text2Config (Speech2Text2 model)
- speecht5: SpeechT5Config (SpeechT5 model)
- splinter: SplinterConfig (Splinter model)
- squeezebert: SqueezeBertConfig (SqueezeBERT model)
- stablelm: StableLmConfig (StableLm model)
- starcoder2: Starcoder2Config (Starcoder2 model)
- superpoint: SuperPointConfig (SuperPoint model)
- swiftformer: SwiftFormerConfig (SwiftFormer model)
- swin: SwinConfig (Swin Transformer model)
- swin2sr: Swin2SRConfig (Swin2SR model)
- swinv2: Swinv2Config (Swin Transformer V2 model)
- switch_transformers: SwitchTransformersConfig (SwitchTransformers model)
- t5: T5Config (T5 model)
- table-transformer: TableTransformerConfig (Table Transformer model)
- tapas: TapasConfig (TAPAS model)
- textnet: TextNetConfig (TextNet model)
- time_series_transformer: TimeSeriesTransformerConfig (Time Series Transformer model)
- timesformer: TimesformerConfig (TimeSformer model)
- timm_backbone: TimmBackboneConfig (TimmBackbone model)
- timm_wrapper: TimmWrapperConfig (TimmWrapperModel model)
- trajectory_transformer: TrajectoryTransformerConfig (Trajectory Transformer model)
- transfo-xl: TransfoXLConfig (Transformer-XL model)
- trocr: TrOCRConfig (TrOCR model)
- tvlt: TvltConfig (TVLT model)
- tvp: TvpConfig (TVP model)
- udop: UdopConfig (UDOP model)
- umt5: UMT5Config (UMT5 model)
- unispeech: UniSpeechConfig (UniSpeech model)
- unispeech-sat: UniSpeechSatConfig (UniSpeechSat model)
- univnet: UnivNetConfig (UnivNet model)
- upernet: UperNetConfig (UPerNet model)
- van: VanConfig (VAN model)
- video_llava: VideoLlavaConfig (VideoLlava model)
- videomae: VideoMAEConfig (VideoMAE model)
- vilt: ViltConfig (ViLT model)
- vipllava: VipLlavaConfig (VipLlava model)
- vision-encoder-decoder: VisionEncoderDecoderConfig (Vision Encoder decoder model)
- vision-text-dual-encoder: VisionTextDualEncoderConfig (VisionTextDualEncoder model)
- visual_bert: VisualBertConfig (VisualBERT model)
- vit: ViTConfig (ViT model)
- vit_hybrid: ViTHybridConfig (ViT Hybrid model)
- vit_mae: ViTMAEConfig (ViTMAE model)
- vit_msn: ViTMSNConfig (ViTMSN model)
- vitdet: VitDetConfig (VitDet model)
- vitmatte: VitMatteConfig (ViTMatte model)
- vitpose: VitPoseConfig (VitPose model)
- vitpose_backbone: VitPoseBackboneConfig (VitPoseBackbone model)
- vits: VitsConfig (VITS model)
- vivit: VivitConfig (ViViT model)
- wav2vec2: Wav2Vec2Config (Wav2Vec2 model)
- wav2vec2-bert: Wav2Vec2BertConfig (Wav2Vec2-BERT model)
- wav2vec2-conformer: Wav2Vec2ConformerConfig (Wav2Vec2-Conformer model)
- wavlm: WavLMConfig (WavLM model)
- whisper: WhisperConfig (Whisper model)
- xclip: XCLIPConfig (X-CLIP model)
- xglm: XGLMConfig (XGLM model)
- xlm: XLMConfig (XLM model)
- xlm-prophetnet: XLMProphetNetConfig (XLM-ProphetNet model)
- xlm-roberta: XLMRobertaConfig (XLM-RoBERTa model)
- xlm-roberta-xl: XLMRobertaXLConfig (XLM-RoBERTa-XL model)
- xlnet: XLNetConfig (XLNet model)
- xmod: XmodConfig (X-MOD model)
- yolos: YolosConfig (YOLOS model)
- yoso: YosoConfig (YOSO model)
- zamba: ZambaConfig (Zamba model)
- zoedepth: ZoeDepthConfig (ZoeDepth model)
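The selection order described above can be sketched in pure Python: prefer the loaded model_type, and otherwise fall back to substring matching on the name or path. CONFIG_NAMES (a four-entry excerpt of the list) and select_config_class are hypothetical. Trying longer keys first is an assumption about how overlapping keys such as "roberta" and "xlm-roberta" are disambiguated.

```python
# Hypothetical excerpt of the model_type -> configuration class mapping listed above.
CONFIG_NAMES = {
    "bert": "BertConfig",
    "roberta": "RobertaConfig",
    "xlm-roberta": "XLMRobertaConfig",
    "gpt2": "GPT2Config",
}


def select_config_class(config_dict: dict, name_or_path: str) -> str:
    # 1) Prefer the explicit model_type stored in the loaded configuration.
    model_type = config_dict.get("model_type")
    if model_type is not None:
        return CONFIG_NAMES[model_type]
    # 2) Otherwise, fall back to substring matching on the name or path, trying
    #    longer keys first so that "xlm-roberta" wins over "roberta" and "bert".
    for pattern in sorted(CONFIG_NAMES, key=len, reverse=True):
        if pattern in name_or_path:
            return CONFIG_NAMES[pattern]
    raise ValueError(f"Could not infer a configuration class for {name_or_path!r}")


print(select_config_class({"model_type": "gpt2"}, "my-checkpoint"))  # GPT2Config
print(select_config_class({}, "FacebookAI/xlm-roberta-base"))  # XLMRobertaConfig
```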
Examples:
>>> from transformers import AutoConfig
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
>>> # Download configuration from huggingface.co (user-uploaded) and cache.
>>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
>>> # If the configuration file is in a directory (e.g., it was saved using save_pretrained('./test/bert_saved_model/')).
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")
>>> # Load a specific configuration file.
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
>>> # Change some config attributes when loading a pretrained config.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
>>> config.output_attentions
True
>>> config, unused_kwargs = AutoConfig.from_pretrained(
... "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
... )
>>> config.output_attentions
True
>>> unused_kwargs
{'foo': False}
register
( model_type, config, exist_ok=False )
Parameters
- model_type (str): The model type, like "bert" or "gpt".
- config (PretrainedConfig): The config to register.
Register a new configuration for this class.
AutoTokenizer
This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when created with the AutoTokenizer.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
( pretrained_model_name_or_path, *inputs, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike): Can be either:
  - A string, the model id of a predefined tokenizer hosted inside a model repo on huggingface.co.
  - A path to a directory containing vocabulary files required by the tokenizer, for instance saved using the save_pretrained() method, e.g., ./my_model_directory/.
  - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (like Bert or XLNet), e.g., ./my_model_directory/vocab.txt. (Not applicable to all derived classes)
- inputs (additional positional arguments, optional): Will be passed along to the Tokenizer __init__() method.
- config (PretrainedConfig, optional): The configuration object used to determine the tokenizer class to instantiate.
- cache_dir (str or os.PathLike, optional): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download: Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional): A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- revision (str, optional, defaults to "main"): The specific model version to use. It can be a branch name, a tag name, or a commit id. Because huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- subfolder (str, optional): If the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g., for facebook/rag-token-base), specify it here.
- use_fast (bool, optional, defaults to True): Use a fast Rust-based tokenizer if one is supported for the given model. If no fast tokenizer is available for the model, a normal Python-based tokenizer is returned instead.
- tokenizer_type (str, optional): Tokenizer type to be loaded.
- trust_remote_code (bool, optional, defaults to False): Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code from the Hub on your local machine.
- kwargs (additional keyword arguments, optional): Will be passed to the Tokenizer __init__() method. Can be used to set special tokens like bos_token, eos_token, unk_token, sep_token, pad_token, cls_token, mask_token, additional_special_tokens. See the parameters of __init__() for more details.
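The use_fast behavior can be sketched with a table of (slow, fast) class names per model_type. The table mirrors the albert, bloom, and biogpt entries of the mapping list for this method. TOKENIZER_NAMES and pick_tokenizer are hypothetical. The sketch also assumes that the fast class is still returned when no slow class exists, even with use_fast=False. This assumption is inferred from fast-only entries such as bloom.

```python
from typing import Optional

# Hypothetical excerpt: model_type -> (slow class name, fast class name); None means absent.
TOKENIZER_NAMES: dict[str, tuple[Optional[str], Optional[str]]] = {
    "albert": ("AlbertTokenizer", "AlbertTokenizerFast"),
    "bloom": (None, "BloomTokenizerFast"),  # fast-only entry
    "biogpt": ("BioGptTokenizer", None),  # slow-only entry
}


def pick_tokenizer(model_type: str, use_fast: bool = True) -> str:
    slow, fast = TOKENIZER_NAMES[model_type]
    # Use the fast class when requested, or when no slow class exists at all.
    if fast is not None and (use_fast or slow is None):
        return fast
    # Otherwise fall back to the Python-based (slow) class.
    if slow is not None:
        return slow
    raise ValueError(f"No tokenizer class available for {model_type!r}")


print(pick_tokenizer("albert"))  # AlbertTokenizerFast
print(pick_tokenizer("albert", use_fast=False))  # AlbertTokenizer
print(pick_tokenizer("biogpt"))  # BioGptTokenizer
```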
Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.
The tokenizer class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert —
AlbertTokenizer
orAlbertTokenizerFast
(ALBERT model) - align — BertTokenizer or BertTokenizerFast (ALIGN model)
- aria — LlamaTokenizer or LlamaTokenizerFast (Aria model)
- bark — BertTokenizer or BertTokenizerFast (Bark model)
- bart — BartTokenizer or BartTokenizerFast (BART model)
- barthez — BarthezTokenizer or BarthezTokenizerFast (BARThez model)
- bartpho — BartphoTokenizer (BARTpho model)
- bert — BertTokenizer or BertTokenizerFast (BERT model)
- bert-generation —
BertGenerationTokenizer
(Bert Generation model) - bert-japanese — BertJapaneseTokenizer (BertJapanese model)
- bertweet — BertweetTokenizer (BERTweet model)
- big_bird —
BigBirdTokenizer
orBigBirdTokenizerFast
(BigBird model) - bigbird_pegasus —
PegasusTokenizer
orPegasusTokenizerFast
(BigBird-Pegasus model) - biogpt — BioGptTokenizer (BioGpt model)
- blenderbot —
BlenderbotTokenizer
orBlenderbotTokenizerFast
(Blenderbot model) - blenderbot-small —
BlenderbotSmallTokenizer
(BlenderbotSmall model) - blip — BertTokenizer or BertTokenizerFast (BLIP model)
- blip-2 —
GPT2Tokenizer
orGPT2TokenizerFast
(BLIP-2 model) - bloom —
BloomTokenizerFast
(BLOOM model) - bridgetower —
RobertaTokenizer
orRobertaTokenizerFast
(BridgeTower model) - bros — BertTokenizer or BertTokenizerFast (BROS model)
- byt5 —
ByT5Tokenizer
(ByT5 model) - camembert —
CamembertTokenizer
orCamembertTokenizerFast
(CamemBERT model) - canine —
CanineTokenizer
(CANINE model) - chameleon — LlamaTokenizer or LlamaTokenizerFast (Chameleon model)
- chinese_clip — BertTokenizer or BertTokenizerFast (Chinese-CLIP model)
- clap —
RobertaTokenizer
orRobertaTokenizerFast
(CLAP model) - clip — CLIPTokenizer or CLIPTokenizerFast (CLIP model)
- clipseg — CLIPTokenizer or CLIPTokenizerFast (CLIPSeg model)
- clvp —
ClvpTokenizer
(CLVP model) - code_llama —
CodeLlamaTokenizer
orCodeLlamaTokenizerFast
(CodeLlama model) - codegen —
CodeGenTokenizer
orCodeGenTokenizerFast
(CodeGen model) - cohere — CohereTokenizerFast (Cohere model)
- cohere2 — CohereTokenizerFast (Cohere2 model)
- colpali — LlamaTokenizer or LlamaTokenizerFast (ColPali model)
- convbert — ConvBertTokenizer or ConvBertTokenizerFast (ConvBERT model)
- cpm —
CpmTokenizer
orCpmTokenizerFast
(CPM model) - cpmant —
CpmAntTokenizer
(CPM-Ant model) - ctrl —
CTRLTokenizer
(CTRL model) - data2vec-audio —
Wav2Vec2CTCTokenizer
(Data2VecAudio model) - data2vec-text —
RobertaTokenizer
orRobertaTokenizerFast
(Data2VecText model) - dbrx —
GPT2Tokenizer
orGPT2TokenizerFast
(DBRX model) - deberta — DebertaTokenizer or DebertaTokenizerFast (DeBERTa model)
- deberta-v2 — DebertaV2Tokenizer or DebertaV2TokenizerFast (DeBERTa-v2 model)
- diffllama — LlamaTokenizer or LlamaTokenizerFast (DiffLlama model)
- distilbert —
DistilBertTokenizer
orDistilBertTokenizerFast
(DistilBERT model) - dpr —
DPRQuestionEncoderTokenizer
orDPRQuestionEncoderTokenizerFast
(DPR model) - electra —
ElectraTokenizer
orElectraTokenizerFast
(ELECTRA model) - emu3 —
GPT2Tokenizer
orGPT2TokenizerFast
(Emu3 model) - ernie — BertTokenizer or BertTokenizerFast (ERNIE model)
- ernie_m —
ErnieMTokenizer
(ErnieM model) - esm — EsmTokenizer (ESM model)
- falcon —
PreTrainedTokenizerFast
(Falcon model) - falcon_mamba —
GPTNeoXTokenizerFast
(FalconMamba model) - fastspeech2_conformer — (FastSpeech2Conformer model)
- flaubert —
FlaubertTokenizer
(FlauBERT model) - fnet —
FNetTokenizer
orFNetTokenizerFast
(FNet model) - fsmt —
FSMTTokenizer
(FairSeq Machine-Translation model) - funnel —
FunnelTokenizer
orFunnelTokenizerFast
(Funnel Transformer model) - gemma — GemmaTokenizer or GemmaTokenizerFast (Gemma model)
- gemma2 — GemmaTokenizer or GemmaTokenizerFast (Gemma2 model)
- git — BertTokenizer or BertTokenizerFast (GIT model)
- glm —
PreTrainedTokenizerFast
(GLM model) - gpt-sw3 —
GPTSw3Tokenizer
(GPT-Sw3 model) - gpt2 —
GPT2Tokenizer
orGPT2TokenizerFast
(OpenAI GPT-2 model) - gpt_bigcode —
GPT2Tokenizer
orGPT2TokenizerFast
(GPTBigCode model) - gpt_neo —
GPT2Tokenizer
orGPT2TokenizerFast
(GPT Neo model) - gpt_neox —
GPTNeoXTokenizerFast
(GPT NeoX model) - gpt_neox_japanese — GPTNeoXJapaneseTokenizer (GPT NeoX Japanese model)
- gptj —
GPT2Tokenizer
orGPT2TokenizerFast
(GPT-J model) - gptsan-japanese —
GPTSanJapaneseTokenizer
(GPTSAN-japanese model) - grounding-dino — BertTokenizer or BertTokenizerFast (Grounding DINO model)
- groupvit — CLIPTokenizer or CLIPTokenizerFast (GroupViT model)
- helium —
PreTrainedTokenizerFast
(Helium model) - herbert —
HerbertTokenizer
orHerbertTokenizerFast
(HerBERT model) - hubert —
Wav2Vec2CTCTokenizer
(Hubert model) - ibert —
RobertaTokenizer
orRobertaTokenizerFast
(I-BERT model) - idefics — LlamaTokenizerFast (IDEFICS model)
- idefics2 — LlamaTokenizer or LlamaTokenizerFast (Idefics2 model)
- idefics3 — LlamaTokenizer or LlamaTokenizerFast (Idefics3 model)
- instructblip —
GPT2Tokenizer
orGPT2TokenizerFast
(InstructBLIP model) - instructblipvideo —
GPT2Tokenizer
orGPT2TokenizerFast
(InstructBlipVideo model) - jamba — LlamaTokenizer or LlamaTokenizerFast (Jamba model)
- jetmoe — LlamaTokenizer or LlamaTokenizerFast (JetMoe model)
- jukebox —
JukeboxTokenizer
(Jukebox model) - kosmos-2 —
XLMRobertaTokenizer
orXLMRobertaTokenizerFast
(KOSMOS-2 model) - layoutlm —
LayoutLMTokenizer
orLayoutLMTokenizerFast
(LayoutLM model) - layoutlmv2 —
LayoutLMv2Tokenizer
orLayoutLMv2TokenizerFast
(LayoutLMv2 model) - layoutlmv3 —
LayoutLMv3Tokenizer
orLayoutLMv3TokenizerFast
(LayoutLMv3 model) - layoutxlm —
LayoutXLMTokenizer
orLayoutXLMTokenizerFast
(LayoutXLM model) - led —
LEDTokenizer
orLEDTokenizerFast
(LED model) - lilt —
LayoutLMv3Tokenizer
orLayoutLMv3TokenizerFast
(LiLT model) - llama — LlamaTokenizer or LlamaTokenizerFast (LLaMA model)
- llava — LlamaTokenizer or LlamaTokenizerFast (LLaVa model)
- llava_next — LlamaTokenizer or LlamaTokenizerFast (LLaVA-NeXT model)
- llava_next_video — LlamaTokenizer or LlamaTokenizerFast (LLaVa-NeXT-Video model)
- llava_onevision — LlamaTokenizer or LlamaTokenizerFast (LLaVA-Onevision model)
- longformer —
LongformerTokenizer
orLongformerTokenizerFast
(Longformer model) - longt5 —
T5Tokenizer
orT5TokenizerFast
(LongT5 model) - luke —
LukeTokenizer
(LUKE model) - lxmert —
LxmertTokenizer
orLxmertTokenizerFast
(LXMERT model) - m2m_100 —
M2M100Tokenizer
(M2M100 model) - mamba —
GPTNeoXTokenizerFast
(Mamba model) - mamba2 —
GPTNeoXTokenizerFast
(mamba2 model) - marian — MarianTokenizer (Marian model)
- mbart —
MBartTokenizer
orMBartTokenizerFast
(mBART model) - mbart50 —
MBart50Tokenizer
orMBart50TokenizerFast
(mBART-50 model) - mega —
RobertaTokenizer
orRobertaTokenizerFast
(MEGA model) - megatron-bert — BertTokenizer or BertTokenizerFast (Megatron-BERT model)
- mgp-str —
MgpstrTokenizer
(MGP-STR model) - mistral — LlamaTokenizer or LlamaTokenizerFast (Mistral model)
- mixtral — LlamaTokenizer or LlamaTokenizerFast (Mixtral model)
- mllama — LlamaTokenizer or LlamaTokenizerFast (Mllama model)
- mluke —
MLukeTokenizer
(mLUKE model) - mobilebert —
MobileBertTokenizer
orMobileBertTokenizerFast
(MobileBERT model) - modernbert —
PreTrainedTokenizerFast
(ModernBERT model) - moonshine —
PreTrainedTokenizerFast
(Moonshine model) - moshi —
PreTrainedTokenizerFast
(Moshi model) - mpnet —
MPNetTokenizer
orMPNetTokenizerFast
(MPNet model) - mpt —
GPTNeoXTokenizerFast
(MPT model) - mra —
RobertaTokenizer
orRobertaTokenizerFast
(MRA model) - mt5 —
MT5Tokenizer
orMT5TokenizerFast
(MT5 model) - musicgen —
T5Tokenizer
orT5TokenizerFast
(MusicGen model) - musicgen_melody —
T5Tokenizer
orT5TokenizerFast
(MusicGen Melody model) - mvp —
MvpTokenizer
orMvpTokenizerFast
(MVP model) - myt5 —
MyT5Tokenizer
(myt5 model) - nezha — BertTokenizer or BertTokenizerFast (Nezha model)
- nllb —
NllbTokenizer
orNllbTokenizerFast
(NLLB model) - nllb-moe —
NllbTokenizer
orNllbTokenizerFast
(NLLB-MOE model) - nystromformer —
AlbertTokenizer
orAlbertTokenizerFast
(Nyströmformer model) - olmo —
GPTNeoXTokenizerFast
(OLMo model) - olmo2 —
GPTNeoXTokenizerFast
(OLMo2 model) - olmoe —
GPTNeoXTokenizerFast
(OLMoE model) - omdet-turbo — CLIPTokenizer or CLIPTokenizerFast (OmDet-Turbo model)
- oneformer — CLIPTokenizer or CLIPTokenizerFast (OneFormer model)
- openai-gpt — OpenAIGPTTokenizer or OpenAIGPTTokenizerFast (OpenAI GPT model)
- opt —
GPT2Tokenizer
orGPT2TokenizerFast
(OPT model) - owlv2 — CLIPTokenizer or CLIPTokenizerFast (OWLv2 model)
- owlvit — CLIPTokenizer or CLIPTokenizerFast (OWL-ViT model)
- paligemma — LlamaTokenizer or LlamaTokenizerFast (PaliGemma model)
- pegasus —
PegasusTokenizer
orPegasusTokenizerFast
(Pegasus model) - pegasus_x —
PegasusTokenizer
orPegasusTokenizerFast
(PEGASUS-X model) - perceiver —
PerceiverTokenizer
(Perceiver model) - persimmon — LlamaTokenizer or LlamaTokenizerFast (Persimmon model)
- phi —
CodeGenTokenizer
orCodeGenTokenizerFast
(Phi model) - phi3 — LlamaTokenizer or LlamaTokenizerFast (Phi3 model)
- phimoe — LlamaTokenizer or LlamaTokenizerFast (Phimoe model)
- phobert — PhobertTokenizer (PhoBERT model)
- pix2struct — T5Tokenizer or T5TokenizerFast (Pix2Struct model)
- pixtral — PreTrainedTokenizerFast (Pixtral model)
- plbart — PLBartTokenizer (PLBart model)
- prophetnet — ProphetNetTokenizer (ProphetNet model)
- qdqbert — BertTokenizer or BertTokenizerFast (QDQBert model)
- qwen2 — Qwen2Tokenizer or Qwen2TokenizerFast (Qwen2 model)
- qwen2_audio — Qwen2Tokenizer or Qwen2TokenizerFast (Qwen2Audio model)
- qwen2_moe — Qwen2Tokenizer or Qwen2TokenizerFast (Qwen2MoE model)
- qwen2_vl — Qwen2Tokenizer or Qwen2TokenizerFast (Qwen2VL model)
- rag — RagTokenizer (RAG model)
- realm — RealmTokenizer or RealmTokenizerFast (REALM model)
- recurrent_gemma — GemmaTokenizer or GemmaTokenizerFast (RecurrentGemma model)
- reformer — ReformerTokenizer or ReformerTokenizerFast (Reformer model)
- rembert — RemBertTokenizer or RemBertTokenizerFast (RemBERT model)
- retribert — RetriBertTokenizer or RetriBertTokenizerFast (RetriBERT model)
- roberta — RobertaTokenizer or RobertaTokenizerFast (RoBERTa model)
- roberta-prelayernorm — RobertaTokenizer or RobertaTokenizerFast (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertTokenizer (RoCBert model)
- roformer — RoFormerTokenizer or RoFormerTokenizerFast (RoFormer model)
- rwkv — GPTNeoXTokenizerFast (RWKV model)
- seamless_m4t — SeamlessM4TTokenizer or SeamlessM4TTokenizerFast (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4TTokenizer or SeamlessM4TTokenizerFast (SeamlessM4Tv2 model)
- siglip — SiglipTokenizer (SigLIP model)
- speech_to_text — Speech2TextTokenizer (Speech2Text model)
- speech_to_text_2 — Speech2Text2Tokenizer (Speech2Text2 model)
- speecht5 — SpeechT5Tokenizer (SpeechT5 model)
- splinter — SplinterTokenizer or SplinterTokenizerFast (Splinter model)
- squeezebert — SqueezeBertTokenizer or SqueezeBertTokenizerFast (SqueezeBERT model)
- stablelm — GPTNeoXTokenizerFast (StableLm model)
- starcoder2 — GPT2Tokenizer or GPT2TokenizerFast (Starcoder2 model)
- switch_transformers — T5Tokenizer or T5TokenizerFast (SwitchTransformers model)
- t5 — T5Tokenizer or T5TokenizerFast (T5 model)
- tapas — TapasTokenizer (TAPAS model)
- tapex — TapexTokenizer (TAPEX model)
- transfo-xl — TransfoXLTokenizer (Transformer-XL model)
- tvp — BertTokenizer or BertTokenizerFast (TVP model)
- udop — UdopTokenizer or UdopTokenizerFast (UDOP model)
- umt5 — T5Tokenizer or T5TokenizerFast (UMT5 model)
- video_llava — LlamaTokenizer or LlamaTokenizerFast (VideoLlava model)
- vilt — BertTokenizer or BertTokenizerFast (ViLT model)
- vipllava — LlamaTokenizer or LlamaTokenizerFast (VipLlava model)
- visual_bert — BertTokenizer or BertTokenizerFast (VisualBERT model)
- vits — VitsTokenizer (VITS model)
- wav2vec2 — Wav2Vec2CTCTokenizer (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2CTCTokenizer (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2CTCTokenizer (Wav2Vec2-Conformer model)
- wav2vec2_phoneme — Wav2Vec2PhonemeCTCTokenizer (Wav2Vec2Phoneme model)
- whisper — WhisperTokenizer or WhisperTokenizerFast (Whisper model)
- xclip — CLIPTokenizer or CLIPTokenizerFast (X-CLIP model)
- xglm — XGLMTokenizer or XGLMTokenizerFast (XGLM model)
- xlm — XLMTokenizer (XLM model)
- xlm-prophetnet — XLMProphetNetTokenizer (XLM-ProphetNet model)
- xlm-roberta — XLMRobertaTokenizer or XLMRobertaTokenizerFast (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaTokenizer or XLMRobertaTokenizerFast (XLM-RoBERTa-XL model)
- xlnet — XLNetTokenizer or XLNetTokenizerFast (XLNet model)
- xmod — XLMRobertaTokenizer or XLMRobertaTokenizerFast (X-MOD model)
- yoso — AlbertTokenizer or AlbertTokenizerFast (YOSO model)
- zamba — LlamaTokenizer or LlamaTokenizerFast (Zamba model)
Examples:
>>> from transformers import AutoTokenizer
>>> # Download vocabulary from huggingface.co and cache.
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
>>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
>>> # If vocabulary files are in a directory (e.g. tokenizer was saved using save_pretrained('./test/saved_model/'))
>>> # tokenizer = AutoTokenizer.from_pretrained("./test/saved_model/")
>>> # Download vocabulary from huggingface.co and define model-specific arguments
>>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
register
(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False)
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- slow_tokenizer_class (PreTrainedTokenizer, optional) — The slow tokenizer to register.
- fast_tokenizer_class (PreTrainedTokenizerFast, optional) — The fast tokenizer to register.
Register a new tokenizer in this mapping.
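The `register` signature above takes a configuration class plus an optional slow and an optional fast tokenizer class, and the mapping above pairs most model types with a slow class and a fast alternative. The sketch below is a toy, pure-Python model of such a registry, not the library's implementation. It assumes that each configuration class maps to a `(slow, fast)` pair and that the fast class is preferred when one is registered (modeled here as a plain `use_fast` boolean). `ToyTokenizerRegistry`, `NewModelConfig`, `NewTokenizer` and `NewTokenizerFast` are hypothetical names.

```python
from typing import Dict, Optional, Tuple


class ToyTokenizerRegistry:
    """Hypothetical stand-in for an auto-class mapping: config class -> (slow, fast)."""

    def __init__(self) -> None:
        self._mapping: Dict[type, Tuple[Optional[type], Optional[type]]] = {}

    def register(
        self,
        config_class: type,
        slow_tokenizer_class: Optional[type] = None,
        fast_tokenizer_class: Optional[type] = None,
    ) -> None:
        # At least one of the two tokenizer classes must be provided.
        if slow_tokenizer_class is None and fast_tokenizer_class is None:
            raise ValueError("Pass a slow tokenizer class, a fast one, or both.")
        self._mapping[config_class] = (slow_tokenizer_class, fast_tokenizer_class)

    def resolve(self, config: object, use_fast: bool = True) -> type:
        # Look up by the concrete class of the configuration object.
        slow, fast = self._mapping[type(config)]
        if use_fast and fast is not None:
            return fast
        # Otherwise use the slow class, or the fast one if it is the only one registered.
        return slow if slow is not None else fast


# Hypothetical classes used only for this illustration.
class NewModelConfig: ...
class NewTokenizer: ...
class NewTokenizerFast: ...


registry = ToyTokenizerRegistry()
registry.register(NewModelConfig, NewTokenizer, NewTokenizerFast)
print(registry.resolve(NewModelConfig()).__name__)                  # NewTokenizerFast
print(registry.resolve(NewModelConfig(), use_fast=False).__name__)  # NewTokenizer
```

Keying the table on the configuration class (rather than a string) mirrors how `register` is called with `config_class` as its first argument.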
AutoFeatureExtractor
This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the library when created with the AutoFeatureExtractor.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (doing so throws an error).
from_pretrained
(pretrained_model_name_or_path, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — This can be either:
  - a string, the model id of a pretrained feature_extractor hosted inside a model repo on huggingface.co.
  - a path to a directory containing a feature extractor file saved using the save_pretrained() method, e.g., ./my_model_directory/.
  - a path or URL to a saved feature extractor JSON file, e.g., ./my_model_directory/preprocessor_config.json.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model feature extractor should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force (re-)downloading the feature extractor files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- token (str or bool, optional) — The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- return_unused_kwargs (bool, optional, defaults to False) — If False, this function returns just the final feature extractor object. If True, it returns a Tuple(feature_extractor, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not feature extractor attributes, i.e., the part of kwargs that has not been used to update feature_extractor and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- kwargs (Dict[str, Any], optional) — The values in kwargs of any keys which are feature extractor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not feature extractor attributes is controlled by the return_unused_kwargs keyword parameter.
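To make the `kwargs` / `return_unused_kwargs` contract concrete, here is a toy sketch under simplifying assumptions. The loaded feature extractor is modeled as a plain dictionary of attributes, keys that already exist override the loaded values, and every other key is handed back when `return_unused_kwargs=True`. `split_feature_extractor_kwargs` is a hypothetical helper, not part of the library, and the attribute names are only illustrative.

```python
from typing import Any, Dict, Tuple, Union


def split_feature_extractor_kwargs(
    loaded_attrs: Dict[str, Any],
    return_unused_kwargs: bool = False,
    **kwargs: Any,
) -> Union[Dict[str, Any], Tuple[Dict[str, Any], Dict[str, Any]]]:
    """Hypothetical helper mimicking the documented override/unused-kwargs behaviour."""
    feature_extractor = dict(loaded_attrs)  # copy so the loaded values are not mutated
    unused_kwargs: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in feature_extractor:
            # Keys that are feature extractor attributes override the loaded values.
            feature_extractor[key] = value
        else:
            # Anything else is collected instead of being applied.
            unused_kwargs[key] = value
    if return_unused_kwargs:
        return feature_extractor, unused_kwargs
    return feature_extractor


loaded = {"sampling_rate": 16000, "do_normalize": True}
fe, unused = split_feature_extractor_kwargs(
    loaded, return_unused_kwargs=True, sampling_rate=8000, foo="bar"
)
print(fe)      # {'sampling_rate': 8000, 'do_normalize': True}
print(unused)  # {'foo': 'bar'}
```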
Instantiate one of the feature extractor classes of the library from a pretrained model.
The feature extractor class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to pattern matching on pretrained_model_name_or_path:
- audio-spectrogram-transformer — ASTFeatureExtractor (Audio Spectrogram Transformer model)
- beit — BeitFeatureExtractor (BEiT model)
- chinese_clip — ChineseCLIPFeatureExtractor (Chinese-CLIP model)
- clap — ClapFeatureExtractor (CLAP model)
- clip — CLIPFeatureExtractor (CLIP model)
- clipseg — ViTFeatureExtractor (CLIPSeg model)
- clvp — ClvpFeatureExtractor (CLVP model)
- conditional_detr — ConditionalDetrFeatureExtractor (Conditional DETR model)
- convnext — ConvNextFeatureExtractor (ConvNeXT model)
- cvt — ConvNextFeatureExtractor (CvT model)
- dac — DacFeatureExtractor (DAC model)
- data2vec-audio — Wav2Vec2FeatureExtractor (Data2VecAudio model)
- data2vec-vision — BeitFeatureExtractor (Data2VecVision model)
- deformable_detr — DeformableDetrFeatureExtractor (Deformable DETR model)
- deit — DeiTFeatureExtractor (DeiT model)
- detr — DetrFeatureExtractor (DETR model)
- dinat — ViTFeatureExtractor (DiNAT model)
- donut-swin — DonutFeatureExtractor (DonutSwin model)
- dpt — DPTFeatureExtractor (DPT model)
- encodec — EncodecFeatureExtractor (EnCodec model)
- flava — FlavaFeatureExtractor (FLAVA model)
- glpn — GLPNFeatureExtractor (GLPN model)
- groupvit — CLIPFeatureExtractor (GroupViT model)
- hubert — Wav2Vec2FeatureExtractor (Hubert model)
- imagegpt — ImageGPTFeatureExtractor (ImageGPT model)
- layoutlmv2 — LayoutLMv2FeatureExtractor (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3FeatureExtractor (LayoutLMv3 model)
- levit — LevitFeatureExtractor (LeViT model)
- maskformer — MaskFormerFeatureExtractor (MaskFormer model)
- mctct — MCTCTFeatureExtractor (M-CTC-T model)
- mimi — EncodecFeatureExtractor (Mimi model)
- mobilenet_v1 — MobileNetV1FeatureExtractor (MobileNetV1 model)
- mobilenet_v2 — MobileNetV2FeatureExtractor (MobileNetV2 model)
- mobilevit — MobileViTFeatureExtractor (MobileViT model)
- moonshine — Wav2Vec2FeatureExtractor (Moonshine model)
- moshi — EncodecFeatureExtractor (Moshi model)
- nat — ViTFeatureExtractor (NAT model)
- owlvit — OwlViTFeatureExtractor (OWL-ViT model)
- perceiver — PerceiverFeatureExtractor (Perceiver model)
- poolformer — PoolFormerFeatureExtractor (PoolFormer model)
- pop2piano — Pop2PianoFeatureExtractor (Pop2Piano model)
- regnet — ConvNextFeatureExtractor (RegNet model)
- resnet — ConvNextFeatureExtractor (ResNet model)
- seamless_m4t — SeamlessM4TFeatureExtractor (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4TFeatureExtractor (SeamlessM4Tv2 model)
- segformer — SegformerFeatureExtractor (SegFormer model)
- sew — Wav2Vec2FeatureExtractor (SEW model)
- sew-d — Wav2Vec2FeatureExtractor (SEW-D model)
- speech_to_text — Speech2TextFeatureExtractor (Speech2Text model)
- speecht5 — SpeechT5FeatureExtractor (SpeechT5 model)
- swiftformer — ViTFeatureExtractor (SwiftFormer model)
- swin — ViTFeatureExtractor (Swin Transformer model)
- swinv2 — ViTFeatureExtractor (Swin Transformer V2 model)
- table-transformer — DetrFeatureExtractor (Table Transformer model)
- timesformer — VideoMAEFeatureExtractor (TimeSformer model)
- tvlt — TvltFeatureExtractor (TVLT model)
- unispeech — Wav2Vec2FeatureExtractor (UniSpeech model)
- unispeech-sat — Wav2Vec2FeatureExtractor (UniSpeechSat model)
- univnet — UnivNetFeatureExtractor (UnivNet model)
- van — ConvNextFeatureExtractor (VAN model)
- videomae — VideoMAEFeatureExtractor (VideoMAE model)
- vilt — ViltFeatureExtractor (ViLT model)
- vit — ViTFeatureExtractor (ViT model)
- vit_mae — ViTFeatureExtractor (ViTMAE model)
- vit_msn — ViTFeatureExtractor (ViTMSN model)
- wav2vec2 — Wav2Vec2FeatureExtractor (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2FeatureExtractor (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2FeatureExtractor (Wav2Vec2-Conformer model)
- wavlm — Wav2Vec2FeatureExtractor (WavLM model)
- whisper — WhisperFeatureExtractor (Whisper model)
- xclip — CLIPFeatureExtractor (X-CLIP model)
- yolos — YolosFeatureExtractor (YOLOS model)
Passing token=True is required when you want to use a private model.
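The selection rule for AutoFeatureExtractor described above (use the config's `model_type` when it is available, otherwise pattern-match on `pretrained_model_name_or_path`) can be sketched in plain Python. This is a simplified toy. It assumes a small excerpt of the mapping and a substring match that tries longer keys first, so that `owlvit` wins over its substring `vit`. The library's actual matching logic may differ, and `resolve_feature_extractor` is a hypothetical name.

```python
from typing import Dict, Optional

# Small excerpt of the model_type -> feature extractor mapping listed above.
FEATURE_EXTRACTOR_NAMES: Dict[str, str] = {
    "vit": "ViTFeatureExtractor",
    "owlvit": "OwlViTFeatureExtractor",
    "mobilevit": "MobileViTFeatureExtractor",
    "whisper": "WhisperFeatureExtractor",
}


def resolve_feature_extractor(name_or_path: str, model_type: Optional[str] = None) -> str:
    """Hypothetical resolver: config model_type first, name pattern matching second."""
    if model_type is not None:
        return FEATURE_EXTRACTOR_NAMES[model_type]
    # No model_type available: try the longest keys first so that a more specific
    # key such as "owlvit" is not shadowed by the shorter "vit".
    for pattern in sorted(FEATURE_EXTRACTOR_NAMES, key=len, reverse=True):
        if pattern in name_or_path:
            return FEATURE_EXTRACTOR_NAMES[pattern]
    raise ValueError(f"Unrecognized feature extractor in {name_or_path!r}.")


print(resolve_feature_extractor("google/owlvit-base-patch32"))                   # OwlViTFeatureExtractor
print(resolve_feature_extractor("./my_model_directory/", model_type="whisper"))  # WhisperFeatureExtractor
```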
Examples:
>>> from transformers import AutoFeatureExtractor
>>> # Download feature extractor from huggingface.co and cache.
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
>>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*)
>>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
register
(config_class, feature_extractor_class, exist_ok=False)
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- feature_extractor_class (FeatureExtractorMixin) — The feature extractor to register.
Register a new feature extractor for this class.
AutoImageProcessor
This is a generic image processor class that will be instantiated as one of the image processor classes of the library when created with the AutoImageProcessor.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (doing so throws an error).
from_pretrained
(pretrained_model_name_or_path, *inputs, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — This can be either:
  - a string, the model id of a pretrained image_processor hosted inside a model repo on huggingface.co.
  - a path to a directory containing an image processor file saved using the save_pretrained() method, e.g., ./my_model_directory/.
  - a path or URL to a saved image processor JSON file, e.g., ./my_model_directory/preprocessor_config.json.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model image processor should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force (re-)downloading the image processor files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- token (str or bool, optional) — The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- use_fast (bool, optional, defaults to False) — Use a fast torchvision-based image processor if one is supported for the given model. If no fast image processor is available for the model, a regular NumPy-based image processor is returned instead.
- return_unused_kwargs (bool, optional, defaults to False) — If False, this function returns just the final image processor object. If True, it returns a Tuple(image_processor, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not image processor attributes, i.e., the part of kwargs that has not been used to update image_processor and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- image_processor_filename (str, optional, defaults to "config.json") — The name of the file in the model directory to use for the image processor config.
- kwargs (Dict[str, Any], optional) — The values in kwargs of any keys which are image processor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not image processor attributes is controlled by the return_unused_kwargs keyword parameter.
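The `use_fast` flag above amounts to choosing between the two entries of an "X or XFast" pair in the mapping below, with a fallback when no fast variant exists. A toy sketch, assuming each model type maps to a `(slow_name, fast_name_or_None)` pair taken from a small excerpt of that list. `pick_image_processor` is a hypothetical helper, and the real selection logic may involve further checks (for example, whether torchvision is installed).

```python
from typing import Dict, Optional, Tuple

# Excerpt of the model_type -> (slow, fast) image processor mapping listed below.
IMAGE_PROCESSOR_NAMES: Dict[str, Tuple[str, Optional[str]]] = {
    "vit": ("ViTImageProcessor", "ViTImageProcessorFast"),
    "detr": ("DetrImageProcessor", "DetrImageProcessorFast"),
    "beit": ("BeitImageProcessor", None),  # no fast variant listed
}


def pick_image_processor(model_type: str, use_fast: bool = False) -> str:
    """Hypothetical helper: return the fast class if requested and available."""
    slow, fast = IMAGE_PROCESSOR_NAMES[model_type]
    if use_fast and fast is not None:
        return fast
    # Either use_fast=False or no fast variant exists: return the slow class.
    return slow


print(pick_image_processor("vit", use_fast=True))   # ViTImageProcessorFast
print(pick_image_processor("beit", use_fast=True))  # BeitImageProcessor
```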
Instantiate one of the image processor classes of the library from a pretrained model.
The image processor class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to pattern matching on pretrained_model_name_or_path:
- align — EfficientNetImageProcessor (ALIGN model)
- aria — AriaImageProcessor (Aria model)
- beit — BeitImageProcessor (BEiT model)
- bit — BitImageProcessor (BiT model)
- blip — BlipImageProcessor (BLIP model)
- blip-2 — BlipImageProcessor (BLIP-2 model)
- bridgetower — BridgeTowerImageProcessor (BridgeTower model)
- chameleon — ChameleonImageProcessor (Chameleon model)
- chinese_clip — ChineseCLIPImageProcessor (Chinese-CLIP model)
- clip — CLIPImageProcessor (CLIP model)
- clipseg — ViTImageProcessor or ViTImageProcessorFast (CLIPSeg model)
- conditional_detr — ConditionalDetrImageProcessor (Conditional DETR model)
- convnext — ConvNextImageProcessor (ConvNeXT model)
- convnextv2 — ConvNextImageProcessor (ConvNeXTV2 model)
- cvt — ConvNextImageProcessor (CvT model)
- data2vec-vision — BeitImageProcessor (Data2VecVision model)
- deformable_detr — DeformableDetrImageProcessor or DeformableDetrImageProcessorFast (Deformable DETR model)
- deit — DeiTImageProcessor (DeiT model)
- depth_anything — DPTImageProcessor (Depth Anything model)
- deta — DetaImageProcessor (DETA model)
- detr — DetrImageProcessor or DetrImageProcessorFast (DETR model)
- dinat — ViTImageProcessor or ViTImageProcessorFast (DiNAT model)
- dinov2 — BitImageProcessor (DINOv2 model)
- donut-swin — DonutImageProcessor (DonutSwin model)
- dpt — DPTImageProcessor (DPT model)
- efficientformer — EfficientFormerImageProcessor (EfficientFormer model)
- efficientnet — EfficientNetImageProcessor (EfficientNet model)
- flava — FlavaImageProcessor (FLAVA model)
- focalnet — BitImageProcessor (FocalNet model)
- fuyu — FuyuImageProcessor (Fuyu model)
- git — CLIPImageProcessor (GIT model)
- glpn — GLPNImageProcessor (GLPN model)
- grounding-dino — GroundingDinoImageProcessor (Grounding DINO model)
- groupvit — CLIPImageProcessor (GroupViT model)
- hiera — BitImageProcessor (Hiera model)
- idefics — IdeficsImageProcessor (IDEFICS model)
- idefics2 — Idefics2ImageProcessor (Idefics2 model)
- idefics3 — Idefics3ImageProcessor (Idefics3 model)
- ijepa — ViTImageProcessor or ViTImageProcessorFast (I-JEPA model)
- imagegpt — ImageGPTImageProcessor (ImageGPT model)
- instructblip — BlipImageProcessor (InstructBLIP model)
- instructblipvideo — InstructBlipVideoImageProcessor (InstructBlipVideo model)
- kosmos-2 — CLIPImageProcessor (KOSMOS-2 model)
- layoutlmv2 — LayoutLMv2ImageProcessor (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3ImageProcessor (LayoutLMv3 model)
- levit — LevitImageProcessor (LeViT model)
- llava — CLIPImageProcessor (LLaVa model)
- llava_next — LlavaNextImageProcessor (LLaVA-NeXT model)
- llava_next_video — LlavaNextVideoImageProcessor (LLaVa-NeXT-Video model)
- llava_onevision — LlavaOnevisionImageProcessor (LLaVA-Onevision model)
- mask2former — Mask2FormerImageProcessor (Mask2Former model)
- maskformer — MaskFormerImageProcessor (MaskFormer model)
- mgp-str — ViTImageProcessor or ViTImageProcessorFast (MGP-STR model)
- mllama — MllamaImageProcessor (Mllama model)
- mobilenet_v1 — MobileNetV1ImageProcessor (MobileNetV1 model)
- mobilenet_v2 — MobileNetV2ImageProcessor (MobileNetV2 model)
- mobilevit — MobileViTImageProcessor (MobileViT model)
- mobilevitv2 — MobileViTImageProcessor (MobileViTV2 model)
- nat — ViTImageProcessor or ViTImageProcessorFast (NAT model)
- nougat — NougatImageProcessor (Nougat model)
- oneformer — OneFormerImageProcessor (OneFormer model)
- owlv2 — Owlv2ImageProcessor (OWLv2 model)
- owlvit — OwlViTImageProcessor (OWL-ViT model)
- paligemma — SiglipImageProcessor (PaliGemma model)
- perceiver — PerceiverImageProcessor (Perceiver model)
- pix2struct — Pix2StructImageProcessor (Pix2Struct model)
- pixtral — PixtralImageProcessor or PixtralImageProcessorFast (Pixtral model)
- poolformer — PoolFormerImageProcessor (PoolFormer model)
- pvt — PvtImageProcessor (PVT model)
- pvt_v2 — PvtImageProcessor (PVTv2 model)
- qwen2_vl — Qwen2VLImageProcessor (Qwen2VL model)
- regnet — ConvNextImageProcessor (RegNet model)
- resnet — ConvNextImageProcessor (ResNet model)
- rt_detr — RTDetrImageProcessor or RTDetrImageProcessorFast (RT-DETR model)
- sam — SamImageProcessor (SAM model)
- segformer — SegformerImageProcessor (SegFormer model)
- seggpt — SegGptImageProcessor (SegGPT model)
- siglip — SiglipImageProcessor (SigLIP model)
- swiftformer — ViTImageProcessor or ViTImageProcessorFast (SwiftFormer model)
- swin — ViTImageProcessor or ViTImageProcessorFast (Swin Transformer model)
- swin2sr — Swin2SRImageProcessor (Swin2SR model)
- swinv2 — ViTImageProcessor or ViTImageProcessorFast (Swin Transformer V2 model)
- table-transformer — DetrImageProcessor (Table Transformer model)
- timesformer — VideoMAEImageProcessor (TimeSformer model)
- timm_wrapper — TimmWrapperImageProcessor (TimmWrapperModel model)
- tvlt — TvltImageProcessor (TVLT model)
- tvp — TvpImageProcessor (TVP model)
- udop — LayoutLMv3ImageProcessor (UDOP model)
- upernet — SegformerImageProcessor (UPerNet model)
- van — ConvNextImageProcessor (VAN model)
- videomae — VideoMAEImageProcessor (VideoMAE model)
- vilt — ViltImageProcessor (ViLT model)
- vipllava — CLIPImageProcessor (VipLlava model)
- vit — ViTImageProcessor or ViTImageProcessorFast (ViT model)
- vit_hybrid — ViTHybridImageProcessor (ViT Hybrid model)
- vit_mae — ViTImageProcessor or ViTImageProcessorFast (ViTMAE model)
- vit_msn — ViTImageProcessor or ViTImageProcessorFast (ViTMSN model)
- vitmatte — VitMatteImageProcessor (ViTMatte model)
- xclip — CLIPImageProcessor (X-CLIP model)
- yolos — YolosImageProcessor (YOLOS model)
- zoedepth — ZoeDepthImageProcessor (ZoeDepth model)
Passing token=True is required when you want to use a private model.
Examples:
>>> from transformers import AutoImageProcessor
>>> # Download image processor from huggingface.co and cache.
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
>>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
register
(config_class, image_processor_class=None, slow_image_processor_class=None, fast_image_processor_class=None, exist_ok=False)
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- image_processor_class (ImageProcessingMixin) — The image processor to register.
Register a new image processor for this class.
AutoProcessor
This is a generic processor class that will be instantiated as one of the processor classes of the library when created with the AutoProcessor.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (doing so throws an error).
from_pretrained
(pretrained_model_name_or_path, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — This can be either:
  - a string, the model id of a pretrained feature_extractor hosted inside a model repo on huggingface.co.
  - a path to a directory containing processor files saved using the save_pretrained() method, e.g., ./my_model_directory/.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model feature extractor should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force (re-)downloading the feature extractor files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- token (str or bool, optional) — The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- return_unused_kwargs (bool, optional, defaults to False) — If False, this function returns just the final feature extractor object. If True, it returns a Tuple(feature_extractor, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not feature extractor attributes, i.e., the part of kwargs that has not been used to update feature_extractor and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- kwargs (Dict[str, Any], optional) — The values in kwargs of any keys which are feature extractor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not feature extractor attributes is controlled by the return_unused_kwargs keyword parameter.
Instantiate one of the processor classes of the library from a pretrained model.
The processor class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible):
- align — AlignProcessor (ALIGN model)
- altclip — AltCLIPProcessor (AltCLIP model)
- aria — AriaProcessor (Aria model)
- bark — BarkProcessor (Bark model)
- blip — BlipProcessor (BLIP model)
- blip-2 — Blip2Processor (BLIP-2 model)
- bridgetower — BridgeTowerProcessor (BridgeTower model)
- chameleon — ChameleonProcessor (Chameleon model)
- chinese_clip — ChineseCLIPProcessor (Chinese-CLIP model)
- clap — ClapProcessor (CLAP model)
- clip — CLIPProcessor (CLIP model)
- clipseg — CLIPSegProcessor (CLIPSeg model)
- clvp — ClvpProcessor (CLVP model)
- colpali — ColPaliProcessor (ColPali model)
- emu3 — Emu3Processor (Emu3 model)
- flava — FlavaProcessor (FLAVA model)
- fuyu — FuyuProcessor (Fuyu model)
- git — GitProcessor (GIT model)
- grounding-dino — GroundingDinoProcessor (Grounding DINO model)
- groupvit — CLIPProcessor (GroupViT model)
- hubert — Wav2Vec2Processor (Hubert model)
- idefics — IdeficsProcessor (IDEFICS model)
- idefics2 — Idefics2Processor (Idefics2 model)
- idefics3 — Idefics3Processor (Idefics3 model)
- instructblip — InstructBlipProcessor (InstructBLIP model)
- instructblipvideo — InstructBlipVideoProcessor (InstructBlipVideo model)
- kosmos-2 — Kosmos2Processor (KOSMOS-2 model)
- layoutlmv2 — LayoutLMv2Processor (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3Processor (LayoutLMv3 model)
- llava — LlavaProcessor (LLaVa model)
- llava_next — LlavaNextProcessor (LLaVA-NeXT model)
- llava_next_video — LlavaNextVideoProcessor (LLaVa-NeXT-Video model)
- llava_onevision — LlavaOnevisionProcessor (LLaVA-Onevision model)
- markuplm — MarkupLMProcessor (MarkupLM model)
- mctct — MCTCTProcessor (M-CTC-T model)
- mgp-str — MgpstrProcessor (MGP-STR model)
- mllama — MllamaProcessor (Mllama model)
- moonshine — Wav2Vec2Processor (Moonshine model)
- oneformer — OneFormerProcessor (OneFormer model)
- owlv2 — Owlv2Processor (OWLv2 model)
- owlvit — OwlViTProcessor (OWL-ViT model)
- paligemma — PaliGemmaProcessor (PaliGemma model)
- pix2struct — Pix2StructProcessor (Pix2Struct model)
- pixtral — PixtralProcessor (Pixtral model)
- pop2piano — Pop2PianoProcessor (Pop2Piano model)
- qwen2_audio — Qwen2AudioProcessor (Qwen2Audio model)
- qwen2_vl — Qwen2VLProcessor (Qwen2VL model)
- sam — SamProcessor (SAM model)
- seamless_m4t — SeamlessM4TProcessor (SeamlessM4T model)
- sew — Wav2Vec2Processor (SEW model)
- sew-d — Wav2Vec2Processor (SEW-D model)
- siglip — SiglipProcessor (SigLIP model)
- speech_to_text — Speech2TextProcessor (Speech2Text model)
- speech_to_text_2 — Speech2Text2Processor (Speech2Text2 model)
- speecht5 — SpeechT5Processor (SpeechT5 model)
- trocr — TrOCRProcessor (TrOCR model)
- tvlt — TvltProcessor (TVLT model)
- tvp — TvpProcessor (TVP model)
- udop — UdopProcessor (UDOP model)
- unispeech — Wav2Vec2Processor (UniSpeech model)
- unispeech-sat — Wav2Vec2Processor (UniSpeechSat model)
- video_llava — VideoLlavaProcessor (VideoLlava model)
- vilt — ViltProcessor (ViLT model)
- vipllava — LlavaProcessor (VipLlava model)
- vision-text-dual-encoder — VisionTextDualEncoderProcessor (VisionTextDualEncoder model)
- wav2vec2 — Wav2Vec2Processor (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2Processor (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2Processor (Wav2Vec2-Conformer model)
- wavlm — Wav2Vec2Processor (WavLM model)
- whisper — WhisperProcessor (Whisper model)
- xclip — XCLIPProcessor (X-CLIP model)
Passing token=True is required when you want to use a private model.
Examples:
>>> from transformers import AutoProcessor
>>> # Download processor from huggingface.co and cache.
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
>>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*)
>>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
register
(config_class, processor_class, exist_ok=False)
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- processor_class (FeatureExtractorMixin) — The processor to register.
Register a new processor for this class.
Generic model classes
The following auto classes can be used to instantiate a base model class without a specific head.
AutoModel
This is a generic model class that will be instantiated as one of the base model classes of the library when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (doing so throws an error).
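The description above amounts to a factory pattern: the generic class refuses direct construction, and its class methods return an instance of a concrete class chosen from the configuration. Below is a minimal pure-Python sketch of that pattern. It assumes a dispatch table keyed on the configuration's class, as in the from_config listing that follows. `ToyAutoModel`, `ToyBertConfig` and `ToyBertModel` are hypothetical stand-ins, not library classes, and the exception type raised by `__init__` is an assumption.

```python
from typing import Dict


class ToyBertConfig:
    """Hypothetical configuration class."""

    def __init__(self, hidden_size: int = 8) -> None:
        self.hidden_size = hidden_size


class ToyBertModel:
    """Hypothetical base model class (no task-specific head)."""

    def __init__(self, config: ToyBertConfig) -> None:
        self.config = config


class ToyAutoModel:
    # Dispatch table: configuration class -> base model class.
    _model_mapping: Dict[type, type] = {ToyBertConfig: ToyBertModel}

    def __init__(self, *args, **kwargs) -> None:
        # Direct instantiation is an error; the exception type here is an assumption.
        raise EnvironmentError(
            f"{type(self).__name__} is designed to be instantiated via from_config()."
        )

    @classmethod
    def from_config(cls, config: object) -> object:
        # Pick the concrete model class from the type of the configuration object.
        model_class = cls._model_mapping.get(type(config))
        if model_class is None:
            raise ValueError(f"Unrecognized configuration class {type(config).__name__}.")
        return model_class(config)


model = ToyAutoModel.from_config(ToyBertConfig())
print(type(model).__name__)  # ToyBertModel
```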
from_config
(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - ASTConfig configuration class: ASTModel (Audio Spectrogram Transformer model)
  - AlbertConfig configuration class: AlbertModel (ALBERT model)
  - AlignConfig configuration class: AlignModel (ALIGN model)
  - AltCLIPConfig configuration class: AltCLIPModel (AltCLIP model)
  - AriaConfig configuration class: AriaForConditionalGeneration (Aria model)
  - AriaTextConfig configuration class: AriaTextModel (AriaText model)
  - AutoformerConfig configuration class: AutoformerModel (Autoformer model)
  - BambaConfig configuration class: BambaModel (Bamba model)
  - BarkConfig configuration class: BarkModel (Bark model)
  - BartConfig configuration class: BartModel (BART model)
  - BeitConfig configuration class: BeitModel (BEiT model)
  - BertConfig configuration class: BertModel (BERT model)
  - BertGenerationConfig configuration class: BertGenerationEncoder (Bert Generation model)
  - BigBirdConfig configuration class: BigBirdModel (BigBird model)
  - BigBirdPegasusConfig configuration class: BigBirdPegasusModel (BigBird-Pegasus model)
  - BioGptConfig configuration class: BioGptModel (BioGpt model)
  - BitConfig configuration class: BitModel (BiT model)
  - BlenderbotConfig configuration class: BlenderbotModel (Blenderbot model)
  - BlenderbotSmallConfig configuration class: BlenderbotSmallModel (BlenderbotSmall model)
  - Blip2Config configuration class: Blip2Model (BLIP-2 model)
  - BlipConfig configuration class: BlipModel (BLIP model)
  - BloomConfig configuration class: BloomModel (BLOOM model)
  - BridgeTowerConfig configuration class: BridgeTowerModel (BridgeTower model)
  - BrosConfig configuration class: BrosModel (BROS model)
  - CLIPConfig configuration class: CLIPModel (CLIP model)
  - CLIPSegConfig configuration class: CLIPSegModel (CLIPSeg model)
  - CLIPTextConfig configuration class: CLIPTextModel (CLIPTextModel model)
  - CLIPVisionConfig configuration class: CLIPVisionModel (CLIPVisionModel model)
  - CTRLConfig configuration class: CTRLModel (CTRL model)
  - CamembertConfig configuration class: CamembertModel (CamemBERT model)
  - CanineConfig configuration class: CanineModel (CANINE model)
  - ChameleonConfig configuration class: ChameleonModel (Chameleon model)
  - ChineseCLIPConfig configuration class: ChineseCLIPModel (Chinese-CLIP model)
  - ChineseCLIPVisionConfig configuration class: ChineseCLIPVisionModel (ChineseCLIPVisionModel model)
  - ClapConfig configuration class: ClapModel (CLAP model)
  - ClvpConfig configuration class: ClvpModelForConditionalGeneration (CLVP model)
  - CodeGenConfig configuration class: CodeGenModel (CodeGen model)
  - Cohere2Config configuration class: Cohere2Model (Cohere2 model)
  - CohereConfig configuration class: CohereModel (Cohere model)
  - ConditionalDetrConfig configuration class: ConditionalDetrModel (Conditional DETR model)
  - ConvBertConfig configuration class: ConvBertModel (ConvBERT model)
  - ConvNextConfig configuration class: ConvNextModel (ConvNeXT model)
  - ConvNextV2Config configuration class: ConvNextV2Model (ConvNeXTV2 model)
  - CpmAntConfig configuration class: CpmAntModel (CPM-Ant model)
  - CvtConfig configuration class: CvtModel (CvT model)
  - DPRConfig configuration class: DPRQuestionEncoder (DPR model)
  - DPTConfig configuration class: DPTModel (DPT model)
  - DacConfig configuration class: DacModel (DAC model)
  - Data2VecAudioConfig configuration class: Data2VecAudioModel (Data2VecAudio model)
  - Data2VecTextConfig configuration class: Data2VecTextModel (Data2VecText model)
  - Data2VecVisionConfig configuration class: Data2VecVisionModel (Data2VecVision model)
  - DbrxConfig configuration class: DbrxModel (DBRX model)
  - DebertaConfig configuration class: DebertaModel (DeBERTa model)
  - DebertaV2Config configuration class: DebertaV2Model (DeBERTa-v2 model)
  - DecisionTransformerConfig configuration class: DecisionTransformerModel (Decision Transformer model)
  - DeformableDetrConfig configuration class: DeformableDetrModel (Deformable DETR model)
  - DeiTConfig configuration class: DeiTModel (DeiT model)
  - DetaConfig configuration class: DetaModel (DETA model)
  - DetrConfig configuration class: DetrModel (DETR model)
  - DiffLlamaConfig configuration class: DiffLlamaModel (DiffLlama model)
  - DinatConfig configuration class: DinatModel (DiNAT model)
  - Dinov2Config configuration class: Dinov2Model (DINOv2 model)
  - Dinov2WithRegistersConfig configuration class: Dinov2WithRegistersModel (DINOv2 with Registers model)
  - DistilBertConfig configuration class: DistilBertModel (DistilBERT model)
  - DonutSwinConfig configuration class: DonutSwinModel (DonutSwin model)
  - EfficientFormerConfig configuration class: EfficientFormerModel (EfficientFormer model)
  - EfficientNetConfig configuration class: EfficientNetModel (EfficientNet model)
  - ElectraConfig configuration class: ElectraModel (ELECTRA model)
  - EncodecConfig configuration class: EncodecModel (EnCodec model)
  - ErnieConfig configuration class: ErnieModel (ERNIE model)
  - ErnieMConfig configuration class: ErnieMModel (ErnieM model)
  - EsmConfig configuration class: EsmModel (ESM model)
  - FNetConfig configuration class: FNetModel (FNet model)
  - FSMTConfig configuration class: FSMTModel (FairSeq Machine-Translation model)
  - FalconConfig configuration class: FalconModel (Falcon model)
  - FalconMambaConfig configuration class: FalconMambaModel (FalconMamba model)
  - FastSpeech2ConformerConfig configuration class: FastSpeech2ConformerModel (FastSpeech2Conformer model)
  - FlaubertConfig configuration class: FlaubertModel (FlauBERT model)
  - FlavaConfig configuration class: FlavaModel (FLAVA model)
  - FocalNetConfig configuration class: FocalNetModel (FocalNet model)
  - FunnelConfig configuration class: FunnelModel or FunnelBaseModel (Funnel Transformer model)
  - GLPNConfig configuration class: GLPNModel (GLPN model)
  - GPT2Config configuration class: GPT2Model (OpenAI GPT-2 model)
  - GPTBigCodeConfig configuration class: GPTBigCodeModel (GPTBigCode model)
  - GPTJConfig configuration class: GPTJModel (GPT-J model)
  - GPTNeoConfig configuration class: GPTNeoModel (GPT Neo model)
  - GPTNeoXConfig configuration class: GPTNeoXModel (GPT NeoX model)
  - GPTNeoXJapaneseConfig configuration class: GPTNeoXJapaneseModel (GPT NeoX Japanese model)
  - GPTSanJapaneseConfig configuration class: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
  - Gemma2Config configuration class: Gemma2Model (Gemma2 model)
  - GemmaConfig configuration class: GemmaModel (Gemma model)
  - GitConfig configuration class: GitModel (GIT model)
  - GlmConfig configuration class: GlmModel (GLM model)
  - GraniteConfig configuration class: GraniteModel (Granite model)
  - GraniteMoeConfig configuration class: GraniteMoeModel (GraniteMoeMoe model)
  - GraphormerConfig configuration class: GraphormerModel (Graphormer model)
  - GroundingDinoConfig configuration class: GroundingDinoModel (Grounding DINO model)
  - GroupViTConfig configuration class: GroupViTModel (GroupViT model)
  - HeliumConfig configuration class: HeliumModel (Helium model)
  - HieraConfig configuration class: HieraModel (Hiera model)
  - HubertConfig configuration class: HubertModel (Hubert model)
  - IBertConfig configuration class: IBertModel (I-BERT model)
  - IJepaConfig configuration class: IJepaModel (I-JEPA model)
  - Idefics2Config configuration class: Idefics2Model (Idefics2 model)
  - Idefics3Config configuration class: Idefics3Model (Idefics3 model)
  - Idefics3VisionConfig configuration class: Idefics3VisionTransformer (Idefics3VisionTransformer model)
  - IdeficsConfig configuration class: IdeficsModel (IDEFICS model)
  - ImageGPTConfig configuration class: ImageGPTModel (ImageGPT model)
  - InformerConfig configuration class: InformerModel (Informer model)
  - JambaConfig configuration class: JambaModel (Jamba model)
  - JetMoeConfig configuration class: JetMoeModel (JetMoe model)
  - JukeboxConfig configuration class: JukeboxModel (Jukebox model)
  - Kosmos2Config configuration class: Kosmos2Model (KOSMOS-2 model)
  - LEDConfig configuration class: LEDModel (LED model)
  - LayoutLMConfig configuration class: LayoutLMModel (LayoutLM model)
  - LayoutLMv2Config configuration class: LayoutLMv2Model (LayoutLMv2 model)
  - LayoutLMv3Config configuration class: LayoutLMv3Model (LayoutLMv3 model)
  - LevitConfig configuration class: LevitModel (LeViT model)
  - LiltConfig configuration class: LiltModel (LiLT model)
  - LlamaConfig configuration class: LlamaModel (LLaMA model)
  - LongT5Config configuration class: LongT5Model (LongT5 model)
  - LongformerConfig configuration class: LongformerModel (Longformer model)
  - LukeConfig configuration class: LukeModel (LUKE model)
  - LxmertConfig configuration class: LxmertModel (LXMERT model)
  - M2M100Config configuration class: M2M100Model (M2M100 model)
  - MBartConfig configuration class: MBartModel (mBART model)
  - MCTCTConfig configuration class: MCTCTModel (M-CTC-T model)
  - MPNetConfig configuration class: MPNetModel (MPNet model)
  - MT5Config configuration class: MT5Model (MT5 model)
  - Mamba2Config configuration class: Mamba2Model (mamba2 model)
  - MambaConfig configuration class: MambaModel (Mamba model)
  - MarianConfig configuration class: MarianModel (Marian model)
  - MarkupLMConfig configuration class: MarkupLMModel (MarkupLM model)
  - Mask2FormerConfig configuration class: Mask2FormerModel (Mask2Former model)
  - MaskFormerConfig configuration class: MaskFormerModel (MaskFormer model)
  - MaskFormerSwinConfig configuration class: MaskFormerSwinModel (MaskFormerSwin model)
  - MegaConfig configuration class: MegaModel (MEGA model)
  - MegatronBertConfig configuration class: MegatronBertModel (Megatron-BERT model)
  - MgpstrConfig configuration class: MgpstrForSceneTextRecognition (MGP-STR model)
  - MimiConfig configuration class: MimiModel (Mimi model)
  - MistralConfig configuration class: MistralModel (Mistral model)
  - MixtralConfig configuration class: MixtralModel (Mixtral model)
  - MobileBertConfig configuration class: MobileBertModel (MobileBERT model)
  - MobileNetV1Config configuration class: MobileNetV1Model (MobileNetV1 model)
  - MobileNetV2Config configuration class: MobileNetV2Model (MobileNetV2 model)
  - MobileViTConfig configuration class: MobileViTModel (MobileViT model)
  - MobileViTV2Config configuration class: MobileViTV2Model (MobileViTV2 model)
  - ModernBertConfig configuration class: ModernBertModel (ModernBERT model)
  - MoonshineConfig configuration class: MoonshineModel (Moonshine model)
  - MoshiConfig configuration class: MoshiModel (Moshi model)
  - MptConfig configuration class: MptModel (MPT model)
  - MraConfig configuration class: MraModel (MRA model)
  - MusicgenConfig configuration class: MusicgenModel (MusicGen model)
  - MusicgenMelodyConfig configuration class: MusicgenMelodyModel (MusicGen Melody model)
  - MvpConfig configuration class: MvpModel (MVP model)
  - NatConfig configuration class: NatModel (NAT model)
  - NemotronConfig configuration class: NemotronModel (Nemotron model)
  - NezhaConfig configuration class: NezhaModel (Nezha model)
  - NllbMoeConfig configuration class: NllbMoeModel (NLLB-MOE model)
  - NystromformerConfig configuration class: NystromformerModel (Nyströmformer model)
  - OPTConfig configuration class: OPTModel (OPT model)
  - Olmo2Config configuration class: Olmo2Model (OLMo2 model)
  - OlmoConfig configuration class: OlmoModel (OLMo model)
  - OlmoeConfig configuration class: OlmoeModel (OLMoE model)
  - OmDetTurboConfig configuration class: OmDetTurboForObjectDetection (OmDet-Turbo model)
  - OneFormerConfig configuration class: OneFormerModel (OneFormer model)
  - OpenAIGPTConfig configuration class: OpenAIGPTModel (OpenAI GPT model)
  - OpenLlamaConfig configuration class: OpenLlamaModel (OpenLlama model)
  - OwlViTConfig configuration class: OwlViTModel (OWL-ViT model)
  - Owlv2Config configuration class: Owlv2Model (OWLv2 model)
  - PLBartConfig configuration class: PLBartModel (PLBart model)
  - PatchTSMixerConfig configuration class: PatchTSMixerModel (PatchTSMixer model)
  - PatchTSTConfig configuration class: PatchTSTModel (PatchTST model)
  - PegasusConfig configuration class: PegasusModel (Pegasus model)
  - PegasusXConfig configuration class: PegasusXModel (PEGASUS-X model)
  - PerceiverConfig configuration class: PerceiverModel (Perceiver model)
  - PersimmonConfig configuration class: PersimmonModel (Persimmon model)
  - Phi3Config configuration class: Phi3Model (Phi3 model)
  - PhiConfig configuration class: PhiModel (Phi model)
  - PhimoeConfig configuration class: PhimoeModel (Phimoe model)
  - PixtralVisionConfig configuration class: PixtralVisionModel (Pixtral model)
  - PoolFormerConfig configuration class: PoolFormerModel (PoolFormer model)
  - ProphetNetConfig configuration class: ProphetNetModel (ProphetNet model)
  - PvtConfig configuration class: PvtModel (PVT model)
  - PvtV2Config configuration class: PvtV2Model (PVTv2 model)
  - QDQBertConfig configuration class: QDQBertModel (QDQBert model)
  - Qwen2AudioEncoderConfig configuration class: Qwen2AudioEncoder (Qwen2AudioEncoder model)
  - Qwen2Config configuration class: Qwen2Model (Qwen2 model)
  - Qwen2MoeConfig configuration class: Qwen2MoeModel (Qwen2MoE model)
  - Qwen2VLConfig configuration class: Qwen2VLModel (Qwen2VL model)
  - RTDetrConfig configuration class: RTDetrModel (RT-DETR model)
  - RecurrentGemmaConfig configuration class: RecurrentGemmaModel (RecurrentGemma model)
  - ReformerConfig configuration class: ReformerModel (Reformer model)
  - RegNetConfig configuration class: RegNetModel (RegNet model)
  - RemBertConfig configuration class: RemBertModel (RemBERT model)
  - ResNetConfig configuration class: ResNetModel (ResNet model)
  - RetriBertConfig configuration class: RetriBertModel (RetriBERT model)
  - RoCBertConfig configuration class: RoCBertModel (RoCBert model)
  - RoFormerConfig configuration class: RoFormerModel (RoFormer model)
  - RobertaConfig configuration class: RobertaModel (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
  - RwkvConfig configuration class: RwkvModel (RWKV model)
  - SEWConfig configuration class: SEWModel (SEW model)
  - SEWDConfig configuration class: SEWDModel (SEW-D model)
  - SamConfig configuration class: SamModel (SAM model)
  - SeamlessM4TConfig configuration class: SeamlessM4TModel (SeamlessM4T model)
  - SeamlessM4Tv2Config configuration class: SeamlessM4Tv2Model (SeamlessM4Tv2 model)
  - SegGptConfig configuration class: SegGptModel (SegGPT model)
  - SegformerConfig configuration class: SegformerModel (SegFormer model)
  - SiglipConfig configuration class: SiglipModel (SigLIP model)
  - SiglipVisionConfig configuration class: SiglipVisionModel (SiglipVisionModel model)
  - Speech2TextConfig configuration class: Speech2TextModel (Speech2Text model)
  - SpeechT5Config configuration class: SpeechT5Model (SpeechT5 model)
  - SplinterConfig configuration class: SplinterModel (Splinter model)
  - SqueezeBertConfig configuration class: SqueezeBertModel (SqueezeBERT model)
  - StableLmConfig configuration class: StableLmModel (StableLm model)
  - Starcoder2Config configuration class: Starcoder2Model (Starcoder2 model)
  - SwiftFormerConfig configuration class: SwiftFormerModel (SwiftFormer model)
  - Swin2SRConfig configuration class: Swin2SRModel (Swin2SR model)
  - SwinConfig configuration class: SwinModel (Swin Transformer model)
  - Swinv2Config configuration class: Swinv2Model (Swin Transformer V2 model)
  - SwitchTransformersConfig configuration class: SwitchTransformersModel (SwitchTransformers model)
  - T5Config configuration class: T5Model (T5 model)
  - TableTransformerConfig configuration class: TableTransformerModel (Table Transformer model)
  - TapasConfig configuration class: TapasModel (TAPAS model)
  - TextNetConfig configuration class: TextNetModel (TextNet model)
  - TimeSeriesTransformerConfig configuration class: TimeSeriesTransformerModel (Time Series Transformer model)
  - TimesformerConfig configuration class: TimesformerModel (TimeSformer model)
  - TimmBackboneConfig configuration class: TimmBackbone (TimmBackbone model)
  - TimmWrapperConfig configuration class: TimmWrapperModel (TimmWrapperModel model)
  - TrajectoryTransformerConfig configuration class: TrajectoryTransformerModel (Trajectory Transformer model)
  - TransfoXLConfig configuration class: TransfoXLModel (Transformer-XL model)
  - TvltConfig configuration class: TvltModel (TVLT model)
  - TvpConfig configuration class: TvpModel (TVP model)
  - UMT5Config configuration class: UMT5Model (UMT5 model)
  - UdopConfig configuration class: UdopModel (UDOP model)
  - UniSpeechConfig configuration class: UniSpeechModel (UniSpeech model)
  - UniSpeechSatConfig configuration class: UniSpeechSatModel (UniSpeechSat model)
  - UnivNetConfig configuration class: UnivNetModel (UnivNet model)
  - VanConfig configuration class: VanModel (VAN model)
  - ViTConfig configuration class: ViTModel (ViT model)
  - ViTHybridConfig configuration class: ViTHybridModel (ViT Hybrid model)
  - ViTMAEConfig configuration class: ViTMAEModel (ViTMAE model)
  - ViTMSNConfig configuration class: ViTMSNModel (ViTMSN model)
  - VideoMAEConfig configuration class: VideoMAEModel (VideoMAE model)
  - ViltConfig configuration class: ViltModel (ViLT model)
  - VisionTextDualEncoderConfig configuration class: VisionTextDualEncoderModel (VisionTextDualEncoder model)
  - VisualBertConfig configuration class: VisualBertModel (VisualBERT model)
  - VitDetConfig configuration class: VitDetModel (VitDet model)
  - VitsConfig configuration class: VitsModel (VITS model)
  - VivitConfig configuration class: VivitModel (ViViT model)
  - Wav2Vec2BertConfig configuration class: Wav2Vec2BertModel (Wav2Vec2-BERT model)
  - Wav2Vec2Config configuration class: Wav2Vec2Model (Wav2Vec2 model)
  - Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerModel (Wav2Vec2-Conformer model)
  - WavLMConfig configuration class: WavLMModel (WavLM model)
  - WhisperConfig configuration class: WhisperModel (Whisper model)
  - XCLIPConfig configuration class: XCLIPModel (X-CLIP model)
  - XGLMConfig configuration class: XGLMModel (XGLM model)
  - XLMConfig configuration class: XLMModel (XLM model)
  - XLMProphetNetConfig configuration class: XLMProphetNetModel (XLM-ProphetNet model)
  - XLMRobertaConfig configuration class: XLMRobertaModel (XLM-RoBERTa model)
  - XLMRobertaXLConfig configuration class: XLMRobertaXLModel (XLM-RoBERTa-XL model)
  - XLNetConfig configuration class: XLNetModel (XLNet model)
  - XmodConfig configuration class: XmodModel (X-MOD model)
  - YolosConfig configuration class: YolosModel (YOLOS model)
  - YosoConfig configuration class: YosoModel (YOSO model)
  - ZambaConfig configuration class: ZambaModel (Zamba model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
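The default rule stated for `attn_implementation` can be written as a small decision function. This is a sketch under stated assumptions. `default_attn_implementation` and `parse_version` are hypothetical helpers, and `sdpa_supported` stands in for whatever availability check applies to a given model. The library's actual selection logic may perform further checks that are not modeled here.

```python
import re

VALID_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2"}


def parse_version(version):
    """Hypothetical helper: '2.1.1+cu121' -> (2, 1, 1); a missing patch counts as 0."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups(default="0"))


def default_attn_implementation(requested, torch_version, sdpa_supported):
    """Pick an attention implementation following the documented default rule."""
    if requested is not None:
        # An explicit choice wins, as long as it is one of the documented values.
        if requested not in VALID_IMPLEMENTATIONS:
            raise ValueError(f"Unknown attn_implementation: {requested!r}")
        return requested
    # No explicit choice: prefer SDPA when it is available and torch >= 2.1.1.
    if sdpa_supported and parse_version(torch_version) >= (2, 1, 1):
        return "sdpa"
    # Otherwise fall back to the manual implementation.
    return "eager"


print(default_attn_implementation(None, "2.3.0", sdpa_supported=True))     # sdpa
print(default_attn_implementation(None, "2.1.0", sdpa_supported=True))     # eager
print(default_attn_implementation("eager", "2.3.0", sdpa_supported=True))  # eager
```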
Instantiates one of the base model classes of the library from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
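To make the note concrete, the following hypothetical toy model separates the configuration, which fixes the shape, from the weights, which hold the values. `ToyModel` and its JSON file layout are invented for illustration and do not reflect the Transformers checkpoint format. The only point is that building from a configuration yields freshly initialized weights, while loading from a saved directory restores them.

```python
import json
import os
import random
import tempfile


class ToyModel:
    """Hypothetical model: the config fixes the shape, the weights hold the values."""

    def __init__(self, config):
        self.config = config
        # Fresh, randomly initialized weights; nothing is read from disk here.
        self.weights = [random.gauss(0.0, 0.02) for _ in range(config["hidden_size"])]

    @classmethod
    def from_config(cls, config):
        # Builds the architecture only; the weights stay randomly initialized.
        return cls(config)

    def save_pretrained(self, directory):
        # Persist both the configuration and the learned weights.
        with open(os.path.join(directory, "config.json"), "w") as f:
            json.dump(self.config, f)
        with open(os.path.join(directory, "weights.json"), "w") as f:
            json.dump(self.weights, f)

    @classmethod
    def from_pretrained(cls, directory):
        # Rebuild the architecture from the saved config, then restore the weights.
        with open(os.path.join(directory, "config.json")) as f:
            model = cls(json.load(f))
        with open(os.path.join(directory, "weights.json")) as f:
            model.weights = json.load(f)
        return model


original = ToyModel.from_config({"hidden_size": 4})
with tempfile.TemporaryDirectory() as tmp:
    original.save_pretrained(tmp)
    reloaded = ToyModel.from_pretrained(tmp)

# Same config, but only from_pretrained() brings back the saved weights.
fresh = ToyModel.from_config(reloaded.config)
print(reloaded.weights == original.weights)  # True
print(fresh.config == original.config)       # True
```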
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from a saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() would be a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
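The two `kwargs` rules above amount to a simple split. The sketch below models them with a hypothetical `ToyConfig` dataclass, a hypothetical `resolve` helper, and a made-up `my_flag` keyword. None of these are part of Transformers, and the real loader handles more cases than this.

```python
from dataclasses import dataclass, fields


@dataclass
class ToyConfig:
    """Hypothetical configuration with two attributes."""

    hidden_size: int = 768
    output_attentions: bool = False


def resolve(kwargs, config=None):
    """Return (config, model_init_kwargs) following the two documented rules."""
    if config is not None:
        # Rule 1: a config was passed explicitly, so every kwarg goes to __init__.
        return config, dict(kwargs)

    # Rule 2: no config was passed. ToyConfig() stands in for a loaded config.
    loaded = ToyConfig()
    config_keys = {field.name for field in fields(loaded)}
    model_kwargs = {}
    for key, value in kwargs.items():
        if key in config_keys:
            setattr(loaded, key, value)  # override the loaded attribute
        else:
            model_kwargs[key] = value  # unknown to the config: pass to __init__
    return loaded, model_kwargs


config, model_kwargs = resolve({"output_attentions": True, "my_flag": 1})
print(config.output_attentions, model_kwargs)  # True {'my_flag': 1}

explicit = ToyConfig()
config, model_kwargs = resolve({"output_attentions": True}, config=explicit)
print(config.output_attentions, model_kwargs)  # False {'output_attentions': True}
```

In the library itself, a comparable split can be observed with `AutoConfig.from_pretrained(..., return_unused_kwargs=True)`, which returns the configuration together with the keyword arguments that are not configuration attributes.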
Instantiate one of the base model classes of the library from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertModel (ALBERT model)
- align — AlignModel (ALIGN model)
- altclip — AltCLIPModel (AltCLIP model)
- aria — AriaForConditionalGeneration (Aria model)
- aria_text — AriaTextModel (AriaText model)
- audio-spectrogram-transformer — ASTModel (Audio Spectrogram Transformer model)
- autoformer — AutoformerModel (Autoformer model)
- bamba — BambaModel (Bamba model)
- bark — BarkModel (Bark model)
- bart — BartModel (BART model)
- beit — BeitModel (BEiT model)
- bert — BertModel (BERT model)
- bert-generation — BertGenerationEncoder (Bert Generation model)
- big_bird — BigBirdModel (BigBird model)
- bigbird_pegasus — BigBirdPegasusModel (BigBird-Pegasus model)
- biogpt — BioGptModel (BioGpt model)
- bit — BitModel (BiT model)
- blenderbot — BlenderbotModel (Blenderbot model)
- blenderbot-small — BlenderbotSmallModel (BlenderbotSmall model)
- blip — BlipModel (BLIP model)
- blip-2 — Blip2Model (BLIP-2 model)
- bloom — BloomModel (BLOOM model)
- bridgetower — BridgeTowerModel (BridgeTower model)
- bros — BrosModel (BROS model)
- camembert — CamembertModel (CamemBERT model)
- canine — CanineModel (CANINE model)
- chameleon — ChameleonModel (Chameleon model)
- chinese_clip — ChineseCLIPModel (Chinese-CLIP model)
- chinese_clip_vision_model — ChineseCLIPVisionModel (ChineseCLIPVisionModel model)
- clap — ClapModel (CLAP model)
- clip — CLIPModel (CLIP model)
- clip_text_model — CLIPTextModel (CLIPTextModel model)
- clip_vision_model — CLIPVisionModel (CLIPVisionModel model)
- clipseg — CLIPSegModel (CLIPSeg model)
- clvp — ClvpModelForConditionalGeneration (CLVP model)
- code_llama — LlamaModel (CodeLlama model)
- codegen — CodeGenModel (CodeGen model)
- cohere — CohereModel (Cohere model)
- cohere2 — Cohere2Model (Cohere2 model)
- conditional_detr — ConditionalDetrModel (Conditional DETR model)
- convbert — ConvBertModel (ConvBERT model)
- convnext — ConvNextModel (ConvNeXT model)
- convnextv2 — ConvNextV2Model (ConvNeXTV2 model)
- cpmant — CpmAntModel (CPM-Ant model)
- ctrl — CTRLModel (CTRL model)
- cvt — CvtModel (CvT model)
- dac — DacModel (DAC model)
- data2vec-audio — Data2VecAudioModel (Data2VecAudio model)
- data2vec-text — Data2VecTextModel (Data2VecText model)
- data2vec-vision — Data2VecVisionModel (Data2VecVision model)
- dbrx — DbrxModel (DBRX model)
- deberta — DebertaModel (DeBERTa model)
- deberta-v2 — DebertaV2Model (DeBERTa-v2 model)
- decision_transformer — DecisionTransformerModel (Decision Transformer model)
- deformable_detr — DeformableDetrModel (Deformable DETR model)
- deit — DeiTModel (DeiT model)
- deta — DetaModel (DETA model)
- detr — DetrModel (DETR model)
- diffllama — DiffLlamaModel (DiffLlama model)
- dinat — DinatModel (DiNAT model)
- dinov2 — Dinov2Model (DINOv2 model)
- dinov2_with_registers — Dinov2WithRegistersModel (DINOv2 with Registers model)
- distilbert — DistilBertModel (DistilBERT model)
- donut-swin — DonutSwinModel (DonutSwin model)
- dpr — DPRQuestionEncoder (DPR model)
- dpt — DPTModel (DPT model)
- efficientformer — EfficientFormerModel (EfficientFormer model)
- efficientnet — EfficientNetModel (EfficientNet model)
- electra — ElectraModel (ELECTRA model)
- encodec — EncodecModel (EnCodec model)
- ernie — ErnieModel (ERNIE model)
- ernie_m — ErnieMModel (ErnieM model)
- esm — EsmModel (ESM model)
- falcon — FalconModel (Falcon model)
- falcon_mamba — FalconMambaModel (FalconMamba model)
- fastspeech2_conformer — FastSpeech2ConformerModel (FastSpeech2Conformer model)
- flaubert — FlaubertModel (FlauBERT model)
- flava — FlavaModel (FLAVA model)
- fnet — FNetModel (FNet model)
- focalnet — FocalNetModel (FocalNet model)
- fsmt — FSMTModel (FairSeq Machine-Translation model)
- funnel — FunnelModel or FunnelBaseModel (Funnel Transformer model)
- gemma — GemmaModel (Gemma model)
- gemma2 — Gemma2Model (Gemma2 model)
- git — GitModel (GIT model)
- glm — GlmModel (GLM model)
- glpn — GLPNModel (GLPN model)
- gpt-sw3 — GPT2Model (GPT-Sw3 model)
- gpt2 — GPT2Model (OpenAI GPT-2 model)
- gpt_bigcode — GPTBigCodeModel (GPTBigCode model)
- gpt_neo — GPTNeoModel (GPT Neo model)
- gpt_neox — GPTNeoXModel (GPT NeoX model)
- gpt_neox_japanese — GPTNeoXJapaneseModel (GPT NeoX Japanese model)
- gptj — GPTJModel (GPT-J model)
- gptsan-japanese — GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- granite — GraniteModel (Granite model)
- granitemoe — GraniteMoeModel (GraniteMoeMoe model)
- graphormer — GraphormerModel (Graphormer model)
- grounding-dino — GroundingDinoModel (Grounding DINO model)
- groupvit — GroupViTModel (GroupViT model)
- helium — HeliumModel (Helium model)
- hiera — HieraModel (Hiera model)
- hubert — HubertModel (Hubert model)
- ibert — IBertModel (I-BERT model)
- idefics — IdeficsModel (IDEFICS model)
- idefics2 — Idefics2Model (Idefics2 model)
- idefics3 — Idefics3Model (Idefics3 model)
- idefics3_vision — Idefics3VisionTransformer (Idefics3VisionTransformer model)
- ijepa — IJepaModel (I-JEPA model)
- imagegpt — ImageGPTModel (ImageGPT model)
- informer — InformerModel (Informer model)
- jamba — JambaModel (Jamba model)
- jetmoe — JetMoeModel (JetMoe model)
- jukebox — JukeboxModel (Jukebox model)
- kosmos-2 — Kosmos2Model (KOSMOS-2 model)
- layoutlm — LayoutLMModel (LayoutLM model)
- layoutlmv2 — LayoutLMv2Model (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3Model (LayoutLMv3 model)
- led — LEDModel (LED model)
- levit — LevitModel (LeViT model)
- lilt — LiltModel (LiLT model)
- llama — LlamaModel (LLaMA model)
- longformer — LongformerModel (Longformer model)
- longt5 — LongT5Model (LongT5 model)
- luke — LukeModel (LUKE model)
- lxmert — LxmertModel (LXMERT model)
- m2m_100 — M2M100Model (M2M100 model)
- mamba — MambaModel (Mamba model)
- mamba2 — Mamba2Model (mamba2 model)
- marian — MarianModel (Marian model)
- markuplm — MarkupLMModel (MarkupLM model)
- mask2former — Mask2FormerModel (Mask2Former model)
- maskformer — MaskFormerModel (MaskFormer model)
- maskformer-swin — MaskFormerSwinModel (MaskFormerSwin model)
- mbart — MBartModel (mBART model)
- mctct — MCTCTModel (M-CTC-T model)
- mega — MegaModel (MEGA model)
- megatron-bert — MegatronBertModel (Megatron-BERT model)
- mgp-str — MgpstrForSceneTextRecognition (MGP-STR model)
- mimi — MimiModel (Mimi model)
- mistral — MistralModel (Mistral model)
- mixtral — MixtralModel (Mixtral model)
- mobilebert — MobileBertModel (MobileBERT model)
- mobilenet_v1 — MobileNetV1Model (MobileNetV1 model)
- mobilenet_v2 — MobileNetV2Model (MobileNetV2 model)
- mobilevit — MobileViTModel (MobileViT model)
- mobilevitv2 — MobileViTV2Model (MobileViTV2 model)
- modernbert — ModernBertModel (ModernBERT model)
- moonshine — MoonshineModel (Moonshine model)
- moshi — MoshiModel (Moshi model)
- mpnet — MPNetModel (MPNet model)
- mpt — MptModel (MPT model)
- mra — MraModel (MRA model)
- mt5 — MT5Model (MT5 model)
- musicgen — MusicgenModel (MusicGen model)
- musicgen_melody — MusicgenMelodyModel (MusicGen Melody model)
- mvp — MvpModel (MVP model)
- nat — NatModel (NAT model)
- nemotron — NemotronModel (Nemotron model)
- nezha — NezhaModel (Nezha model)
- nllb-moe — NllbMoeModel (NLLB-MOE model)
- nystromformer — NystromformerModel (Nyströmformer model)
- olmo — OlmoModel (OLMo model)
- olmo2 — Olmo2Model (OLMo2 model)
- olmoe — OlmoeModel (OLMoE model)
- omdet-turbo — OmDetTurboForObjectDetection (OmDet-Turbo model)
- oneformer — OneFormerModel (OneFormer model)
- open-llama — OpenLlamaModel (OpenLlama model)
- openai-gpt — OpenAIGPTModel (OpenAI GPT model)
- opt — OPTModel (OPT model)
- owlv2 — Owlv2Model (OWLv2 model)
- owlvit — OwlViTModel (OWL-ViT model)
- patchtsmixer — PatchTSMixerModel (PatchTSMixer model)
- patchtst — PatchTSTModel (PatchTST model)
- pegasus — PegasusModel (Pegasus model)
- pegasus_x — PegasusXModel (PEGASUS-X model)
- perceiver — PerceiverModel (Perceiver model)
- persimmon — PersimmonModel (Persimmon model)
- phi — PhiModel (Phi model)
- phi3 — Phi3Model (Phi3 model)
- phimoe — PhimoeModel (Phimoe model)
- pixtral — PixtralVisionModel (Pixtral model)
- plbart — PLBartModel (PLBart model)
- poolformer — PoolFormerModel (PoolFormer model)
- prophetnet — ProphetNetModel (ProphetNet model)
- pvt — PvtModel (PVT model)
- pvt_v2 — PvtV2Model (PVTv2 model)
- qdqbert — QDQBertModel (QDQBert model)
- qwen2 — Qwen2Model (Qwen2 model)
- qwen2_audio_encoder — Qwen2AudioEncoder (Qwen2AudioEncoder model)
- qwen2_moe — Qwen2MoeModel (Qwen2MoE model)
- qwen2_vl — Qwen2VLModel (Qwen2VL model)
- recurrent_gemma — RecurrentGemmaModel (RecurrentGemma model)
- reformer — ReformerModel (Reformer model)
- regnet — RegNetModel (RegNet model)
- rembert — RemBertModel (RemBERT model)
- resnet — ResNetModel (ResNet model)
- retribert — RetriBertModel (RetriBERT model)
- roberta — RobertaModel (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertModel (RoCBert model)
- roformer — RoFormerModel (RoFormer model)
- rt_detr — RTDetrModel (RT-DETR model)
- rwkv — RwkvModel (RWKV model)
- sam — SamModel (SAM model)
- seamless_m4t — SeamlessM4TModel (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4Tv2Model (SeamlessM4Tv2 model)
- segformer — SegformerModel (SegFormer model)
- seggpt — SegGptModel (SegGPT model)
- sew — SEWModel (SEW model)
- sew-d — SEWDModel (SEW-D model)
- siglip — SiglipModel (SigLIP model)
- siglip_vision_model — SiglipVisionModel (SiglipVisionModel model)
- speech_to_text — Speech2TextModel (Speech2Text model)
- speecht5 — SpeechT5Model (SpeechT5 model)
- splinter — SplinterModel (Splinter model)
- squeezebert — SqueezeBertModel (SqueezeBERT model)
- stablelm — StableLmModel (StableLm model)
- starcoder2 — Starcoder2Model (Starcoder2 model)
- swiftformer — SwiftFormerModel (SwiftFormer model)
- swin — SwinModel (Swin Transformer model)
- swin2sr — Swin2SRModel (Swin2SR model)
- swinv2 — Swinv2Model (Swin Transformer V2 model)
- switch_transformers — SwitchTransformersModel (SwitchTransformers model)
- t5 — T5Model (T5 model)
- table-transformer — TableTransformerModel (Table Transformer model)
- tapas — TapasModel (TAPAS model)
- textnet — TextNetModel (TextNet model)
- time_series_transformer — TimeSeriesTransformerModel (Time Series Transformer model)
- timesformer — TimesformerModel (TimeSformer model)
- timm_backbone — TimmBackbone (TimmBackbone model)
- timm_wrapper — TimmWrapperModel (TimmWrapperModel model)
- trajectory_transformer — TrajectoryTransformerModel (Trajectory Transformer model)
- transfo-xl — TransfoXLModel (Transformer-XL model)
- tvlt — TvltModel (TVLT model)
- tvp — TvpModel (TVP model)
- udop — UdopModel (UDOP model)
- umt5 — UMT5Model (UMT5 model)
- unispeech — UniSpeechModel (UniSpeech model)
- unispeech-sat — UniSpeechSatModel (UniSpeechSat model)
- univnet — UnivNetModel (UnivNet model)
- van — VanModel (VAN model)
- videomae — VideoMAEModel (VideoMAE model)
- vilt — ViltModel (ViLT model)
- vision-text-dual-encoder — VisionTextDualEncoderModel (VisionTextDualEncoder model)
- visual_bert — VisualBertModel (VisualBERT model)
- vit — ViTModel (ViT model)
- vit_hybrid — ViTHybridModel (ViT Hybrid model)
- vit_mae — ViTMAEModel (ViTMAE model)
- vit_msn — ViTMSNModel (ViTMSN model)
- vitdet — VitDetModel (VitDet model)
- vits — VitsModel (VITS model)
- vivit — VivitModel (ViViT model)
- wav2vec2 — Wav2Vec2Model (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertModel (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerModel (Wav2Vec2-Conformer model)
- wavlm — WavLMModel (WavLM model)
- whisper — WhisperModel (Whisper model)
- xclip — XCLIPModel (X-CLIP model)
- xglm — XGLMModel (XGLM model)
- xlm — XLMModel (XLM model)
- xlm-prophetnet — XLMProphetNetModel (XLM-ProphetNet model)
- xlm-roberta — XLMRobertaModel (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLModel (XLM-RoBERTa-XL model)
- xlnet — XLNetModel (XLNet model)
- xmod — XmodModel (X-MOD model)
- yolos — YolosModel (YOLOS model)
- yoso — YosoModel (YOSO model)
- zamba — ZambaModel
(Zamba model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModel.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModel.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
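The evaluation-mode default mentioned above can be checked without network access. This sketch assumes PyTorch is installed and uses a tiny BertConfig with arbitrary sizes (chosen only for speed), saved to a temporary directory instead of a Hub model id.

```python
import tempfile

from transformers import AutoModel, BertConfig

# A tiny BERT configuration; the sizes are arbitrary and only keep this fast and offline.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=37,
)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Save a randomly initialized model locally instead of downloading one.
    AutoModel.from_config(config).save_pretrained(tmp_dir)
    model = AutoModel.from_pretrained(tmp_dir)
    starts_in_eval_mode = not model.training  # from_pretrained() left dropout disabled

    model.train()  # re-enable dropout before fine-tuning
    now_in_train_mode = model.training
```

Calling model.eval() again after training restores inference behavior.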
TFAutoModel
This is a generic model class that will be instantiated as one of the base model classes of the library when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: TFAlbertModel (ALBERT model)
  - BartConfig configuration class: TFBartModel (BART model)
  - BertConfig configuration class: TFBertModel (BERT model)
  - BlenderbotConfig configuration class: TFBlenderbotModel (Blenderbot model)
  - BlenderbotSmallConfig configuration class: TFBlenderbotSmallModel (BlenderbotSmall model)
  - BlipConfig configuration class: TFBlipModel (BLIP model)
  - CLIPConfig configuration class: TFCLIPModel (CLIP model)
  - CTRLConfig configuration class: TFCTRLModel (CTRL model)
  - CamembertConfig configuration class: TFCamembertModel (CamemBERT model)
  - ConvBertConfig configuration class: TFConvBertModel (ConvBERT model)
  - ConvNextConfig configuration class: TFConvNextModel (ConvNeXT model)
  - ConvNextV2Config configuration class: TFConvNextV2Model (ConvNeXTV2 model)
  - CvtConfig configuration class: TFCvtModel (CvT model)
  - DPRConfig configuration class: TFDPRQuestionEncoder (DPR model)
  - Data2VecVisionConfig configuration class: TFData2VecVisionModel (Data2VecVision model)
  - DebertaConfig configuration class: TFDebertaModel (DeBERTa model)
  - DebertaV2Config configuration class: TFDebertaV2Model (DeBERTa-v2 model)
  - DeiTConfig configuration class: TFDeiTModel (DeiT model)
  - DistilBertConfig configuration class: TFDistilBertModel (DistilBERT model)
  - EfficientFormerConfig configuration class: TFEfficientFormerModel (EfficientFormer model)
  - ElectraConfig configuration class: TFElectraModel (ELECTRA model)
  - EsmConfig configuration class: TFEsmModel (ESM model)
  - FlaubertConfig configuration class: TFFlaubertModel (FlauBERT model)
  - FunnelConfig configuration class: TFFunnelModel or TFFunnelBaseModel (Funnel Transformer model)
  - GPT2Config configuration class: TFGPT2Model (OpenAI GPT-2 model)
  - GPTJConfig configuration class: TFGPTJModel (GPT-J model)
  - GroupViTConfig configuration class: TFGroupViTModel (GroupViT model)
  - HubertConfig configuration class: TFHubertModel (Hubert model)
  - IdeficsConfig configuration class: TFIdeficsModel (IDEFICS model)
  - LEDConfig configuration class: TFLEDModel (LED model)
  - LayoutLMConfig configuration class: TFLayoutLMModel (LayoutLM model)
  - LayoutLMv3Config configuration class: TFLayoutLMv3Model (LayoutLMv3 model)
  - LongformerConfig configuration class: TFLongformerModel (Longformer model)
  - LxmertConfig configuration class: TFLxmertModel (LXMERT model)
  - MBartConfig configuration class: TFMBartModel (mBART model)
  - MPNetConfig configuration class: TFMPNetModel (MPNet model)
  - MT5Config configuration class: TFMT5Model (MT5 model)
  - MarianConfig configuration class: TFMarianModel (Marian model)
  - MistralConfig configuration class: TFMistralModel (Mistral model)
  - MobileBertConfig configuration class: TFMobileBertModel (MobileBERT model)
  - MobileViTConfig configuration class: TFMobileViTModel (MobileViT model)
  - OPTConfig configuration class: TFOPTModel (OPT model)
  - OpenAIGPTConfig configuration class: TFOpenAIGPTModel (OpenAI GPT model)
  - PegasusConfig configuration class: TFPegasusModel (Pegasus model)
  - RegNetConfig configuration class: TFRegNetModel (RegNet model)
  - RemBertConfig configuration class: TFRemBertModel (RemBERT model)
  - ResNetConfig configuration class: TFResNetModel (ResNet model)
  - RoFormerConfig configuration class: TFRoFormerModel (RoFormer model)
  - RobertaConfig configuration class: TFRobertaModel (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
  - SamConfig configuration class: TFSamModel (SAM model)
  - SegformerConfig configuration class: TFSegformerModel (SegFormer model)
  - Speech2TextConfig configuration class: TFSpeech2TextModel (Speech2Text model)
  - SwiftFormerConfig configuration class: TFSwiftFormerModel (SwiftFormer model)
  - SwinConfig configuration class: TFSwinModel (Swin Transformer model)
  - T5Config configuration class: TFT5Model (T5 model)
  - TapasConfig configuration class: TFTapasModel (TAPAS model)
  - TransfoXLConfig configuration class: TFTransfoXLModel (Transformer-XL model)
  - ViTConfig configuration class: TFViTModel (ViT model)
  - ViTMAEConfig configuration class: TFViTMAEModel (ViTMAE model)
  - VisionTextDualEncoderConfig configuration class: TFVisionTextDualEncoderModel (VisionTextDualEncoder model)
  - Wav2Vec2Config configuration class: TFWav2Vec2Model (Wav2Vec2 model)
  - WhisperConfig configuration class: TFWhisperModel (Whisper model)
  - XGLMConfig configuration class: TFXGLMModel (XGLM model)
  - XLMConfig configuration class: TFXLMModel (XLM model)
  - XLMRobertaConfig configuration class: TFXLMRobertaModel (XLM-RoBERTa model)
  - XLNetConfig configuration class: TFXLNetModel (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the base model classes of the library from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
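The difference between from_config() and from_pretrained() can be checked offline. This sketch uses the PyTorch AutoModel with a tiny, arbitrarily sized BertConfig (an assumption made for speed); TFAutoModel and FlaxAutoModel document the same behavior but are not exercised here. It saves one randomly initialized model, then compares its embedding matrix with a from_pretrained() reload and with a second from_config() instance.

```python
import tempfile

import torch
from transformers import AutoModel, BertConfig

# Tiny, arbitrary sizes so nothing needs to be downloaded.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=37,
)

torch.manual_seed(0)
original = AutoModel.from_config(config)  # architecture only, random weights
reference = original.embeddings.word_embeddings.weight.detach().clone()

with tempfile.TemporaryDirectory() as tmp_dir:
    original.save_pretrained(tmp_dir)
    # from_pretrained() restores the saved weights...
    reloaded = AutoModel.from_pretrained(tmp_dir)
    weights_restored = torch.equal(reference, reloaded.embeddings.word_embeddings.weight)

# ...whereas from_config() builds the same architecture with freshly drawn random weights.
fresh = AutoModel.from_config(config)
weights_redrawn = not torch.equal(reference, fresh.embeddings.word_embeddings.weight)
```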
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
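The second kwargs branch (no config passed) can be observed on the configuration side alone with AutoConfig.from_pretrained and its return_unused_kwargs=True option, which hands back the keys that did not match a configuration attribute. This sketch writes a default BertConfig to a temporary directory so nothing is downloaded; foo is a placeholder key with no meaning to the library.

```python
import tempfile

from transformers import AutoConfig, BertConfig

with tempfile.TemporaryDirectory() as tmp_dir:
    BertConfig().save_pretrained(tmp_dir)  # writes config.json only, no weights
    config, unused_kwargs = AutoConfig.from_pretrained(
        tmp_dir,
        output_attentions=True,  # matches a configuration attribute, so it overrides it
        foo=False,  # placeholder key that matches no attribute, so it is returned
        return_unused_kwargs=True,
    )
```

In the model-loading path, leftover keys like foo are the ones that would be forwarded to the model's __init__.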
Instantiate one of the base model classes of the library from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — TFAlbertModel (ALBERT model)
- bart — TFBartModel (BART model)
- bert — TFBertModel (BERT model)
- blenderbot — TFBlenderbotModel (Blenderbot model)
- blenderbot-small — TFBlenderbotSmallModel (BlenderbotSmall model)
- blip — TFBlipModel (BLIP model)
- camembert — TFCamembertModel (CamemBERT model)
- clip — TFCLIPModel (CLIP model)
- convbert — TFConvBertModel (ConvBERT model)
- convnext — TFConvNextModel (ConvNeXT model)
- convnextv2 — TFConvNextV2Model (ConvNeXTV2 model)
- ctrl — TFCTRLModel (CTRL model)
- cvt — TFCvtModel (CvT model)
- data2vec-vision — TFData2VecVisionModel (Data2VecVision model)
- deberta — TFDebertaModel (DeBERTa model)
- deberta-v2 — TFDebertaV2Model (DeBERTa-v2 model)
- deit — TFDeiTModel (DeiT model)
- distilbert — TFDistilBertModel (DistilBERT model)
- dpr — TFDPRQuestionEncoder (DPR model)
- efficientformer — TFEfficientFormerModel (EfficientFormer model)
- electra — TFElectraModel (ELECTRA model)
- esm — TFEsmModel (ESM model)
- flaubert — TFFlaubertModel (FlauBERT model)
- funnel — TFFunnelModel or TFFunnelBaseModel (Funnel Transformer model)
- gpt-sw3 — TFGPT2Model (GPT-Sw3 model)
- gpt2 — TFGPT2Model (OpenAI GPT-2 model)
- gptj — TFGPTJModel (GPT-J model)
- groupvit — TFGroupViTModel (GroupViT model)
- hubert — TFHubertModel (Hubert model)
- idefics — TFIdeficsModel (IDEFICS model)
- layoutlm — TFLayoutLMModel (LayoutLM model)
- layoutlmv3 — TFLayoutLMv3Model (LayoutLMv3 model)
- led — TFLEDModel (LED model)
- longformer — TFLongformerModel (Longformer model)
- lxmert — TFLxmertModel (LXMERT model)
- marian — TFMarianModel (Marian model)
- mbart — TFMBartModel (mBART model)
- mistral — TFMistralModel (Mistral model)
- mobilebert — TFMobileBertModel (MobileBERT model)
- mobilevit — TFMobileViTModel (MobileViT model)
- mpnet — TFMPNetModel (MPNet model)
- mt5 — TFMT5Model (MT5 model)
- openai-gpt — TFOpenAIGPTModel (OpenAI GPT model)
- opt — TFOPTModel (OPT model)
- pegasus — TFPegasusModel (Pegasus model)
- regnet — TFRegNetModel (RegNet model)
- rembert — TFRemBertModel (RemBERT model)
- resnet — TFResNetModel (ResNet model)
- roberta — TFRobertaModel (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerModel (RoFormer model)
- sam — TFSamModel (SAM model)
- segformer — TFSegformerModel (SegFormer model)
- speech_to_text — TFSpeech2TextModel (Speech2Text model)
- swiftformer — TFSwiftFormerModel (SwiftFormer model)
- swin — TFSwinModel (Swin Transformer model)
- t5 — TFT5Model (T5 model)
- tapas — TFTapasModel (TAPAS model)
- transfo-xl — TFTransfoXLModel (Transformer-XL model)
- vision-text-dual-encoder — TFVisionTextDualEncoderModel (VisionTextDualEncoder model)
- vit — TFViTModel (ViT model)
- vit_mae — TFViTMAEModel (ViTMAE model)
- wav2vec2 — TFWav2Vec2Model (Wav2Vec2 model)
- whisper — TFWhisperModel (Whisper model)
- xglm — TFXGLMModel (XGLM model)
- xlm — TFXLMModel (XLM model)
- xlm-roberta — TFXLMRobertaModel (XLM-RoBERTa model)
- xlnet — TFXLNetModel (XLNet model)
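The selection rule described above (use config.model_type when it is known, otherwise match patterns against pretrained_model_name_or_path) can be sketched in plain Python. This is a hypothetical simplification, not the library's code: MODEL_TYPE_TO_CLASS and resolve_model_class are made-up names holding only a few rows of the TF mapping, and trying longer keys first is an assumption made here so that a path containing "roberta" is not claimed by the shorter key "bert".

```python
# Hypothetical, heavily reduced registry: model_type -> model class name.
MODEL_TYPE_TO_CLASS = {
    "bert": "TFBertModel",
    "roberta": "TFRobertaModel",
    "roberta-prelayernorm": "TFRobertaPreLayerNormModel",
    "xlm-roberta": "TFXLMRobertaModel",
}


def resolve_model_class(name_or_path, model_type=None):
    """Pick a class from model_type if given, else fall back to substring matching."""
    if model_type is not None:
        return MODEL_TYPE_TO_CLASS[model_type]
    # Assumption: check longer keys first, since "roberta" contains "bert"
    # and "xlm-roberta" contains "roberta".
    for pattern in sorted(MODEL_TYPE_TO_CLASS, key=len, reverse=True):
        if pattern in str(name_or_path):
            return MODEL_TYPE_TO_CLASS[pattern]
    raise ValueError(f"Could not infer a model type from {name_or_path!r}")


print(resolve_model_class("FacebookAI/xlm-roberta-base"))
```

Without the length ordering, a name such as FacebookAI/roberta-base could match "bert" first, which is why a explicit model_type in config.json is the more reliable signal.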
Examples:
>>> from transformers import AutoConfig, TFAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModel.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModel
This is a generic model class that will be instantiated as one of the base model classes of the library when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: FlaxAlbertModel (ALBERT model)
  - BartConfig configuration class: FlaxBartModel (BART model)
  - BeitConfig configuration class: FlaxBeitModel (BEiT model)
  - BertConfig configuration class: FlaxBertModel (BERT model)
  - BigBirdConfig configuration class: FlaxBigBirdModel (BigBird model)
  - BlenderbotConfig configuration class: FlaxBlenderbotModel (Blenderbot model)
  - BlenderbotSmallConfig configuration class: FlaxBlenderbotSmallModel (BlenderbotSmall model)
  - BloomConfig configuration class: FlaxBloomModel (BLOOM model)
  - CLIPConfig configuration class: FlaxCLIPModel (CLIP model)
  - Dinov2Config configuration class: FlaxDinov2Model (DINOv2 model)
  - DistilBertConfig configuration class: FlaxDistilBertModel (DistilBERT model)
  - ElectraConfig configuration class: FlaxElectraModel (ELECTRA model)
  - GPT2Config configuration class: FlaxGPT2Model (OpenAI GPT-2 model)
  - GPTJConfig configuration class: FlaxGPTJModel (GPT-J model)
  - GPTNeoConfig configuration class: FlaxGPTNeoModel (GPT Neo model)
  - GemmaConfig configuration class: FlaxGemmaModel (Gemma model)
  - LlamaConfig configuration class: FlaxLlamaModel (LLaMA model)
  - LongT5Config configuration class: FlaxLongT5Model (LongT5 model)
  - MBartConfig configuration class: FlaxMBartModel (mBART model)
  - MT5Config configuration class: FlaxMT5Model (MT5 model)
  - MarianConfig configuration class: FlaxMarianModel (Marian model)
  - MistralConfig configuration class: FlaxMistralModel (Mistral model)
  - OPTConfig configuration class: FlaxOPTModel (OPT model)
  - PegasusConfig configuration class: FlaxPegasusModel (Pegasus model)
  - RegNetConfig configuration class: FlaxRegNetModel (RegNet model)
  - ResNetConfig configuration class: FlaxResNetModel (ResNet model)
  - RoFormerConfig configuration class: FlaxRoFormerModel (RoFormer model)
  - RobertaConfig configuration class: FlaxRobertaModel (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
  - T5Config configuration class: FlaxT5Model (T5 model)
  - ViTConfig configuration class: FlaxViTModel (ViT model)
  - VisionTextDualEncoderConfig configuration class: FlaxVisionTextDualEncoderModel (VisionTextDualEncoder model)
  - Wav2Vec2Config configuration class: FlaxWav2Vec2Model (Wav2Vec2 model)
  - WhisperConfig configuration class: FlaxWhisperModel (Whisper model)
  - XGLMConfig configuration class: FlaxXGLMModel (XGLM model)
  - XLMRobertaConfig configuration class: FlaxXLMRobertaModel (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the base model classes of the library from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
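The configuration-class list for from_config behaves like a lookup table keyed on type(config). The following is a minimal sketch with hypothetical stand-in classes: ToyBertConfig, ToyBertModel, MODEL_MAPPING and model_from_config are invented for illustration, and the real mapping in transformers is lazily populated and far larger.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for configuration classes.
@dataclass
class ToyBertConfig:
    hidden_size: int = 32


@dataclass
class ToyGPT2Config:
    n_embd: int = 32


# Hypothetical stand-ins for model classes; they only keep the config around.
class ToyBertModel:
    def __init__(self, config):
        self.config = config


class ToyGPT2Model:
    def __init__(self, config):
        self.config = config


# Configuration class -> model class, like the "configuration class: model" entries above.
MODEL_MAPPING = {
    ToyBertConfig: ToyBertModel,
    ToyGPT2Config: ToyGPT2Model,
}


def model_from_config(config):
    """Instantiate the model class registered for type(config); no weights are loaded."""
    model_class = MODEL_MAPPING.get(type(config))
    if model_class is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    return model_class(config)
```

Because only the architecture is looked up, the returned object carries no pretrained weights, which is what the note above states.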
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a Flax model once and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the base model classes of the library from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertModel (ALBERT model)
- bart — FlaxBartModel (BART model)
- beit — FlaxBeitModel (BEiT model)
- bert — FlaxBertModel (BERT model)
- big_bird — FlaxBigBirdModel (BigBird model)
- blenderbot — FlaxBlenderbotModel (Blenderbot model)
- blenderbot-small — FlaxBlenderbotSmallModel (BlenderbotSmall model)
- bloom — FlaxBloomModel (BLOOM model)
- clip — FlaxCLIPModel (CLIP model)
- dinov2 — FlaxDinov2Model (DINOv2 model)
- distilbert — FlaxDistilBertModel (DistilBERT model)
- electra — FlaxElectraModel (ELECTRA model)
- gemma — FlaxGemmaModel (Gemma model)
- gpt-sw3 — FlaxGPT2Model (GPT-Sw3 model)
- gpt2 — FlaxGPT2Model (OpenAI GPT-2 model)
- gpt_neo — FlaxGPTNeoModel (GPT Neo model)
- gptj — FlaxGPTJModel (GPT-J model)
- llama — FlaxLlamaModel (LLaMA model)
- longt5 — FlaxLongT5Model (LongT5 model)
- marian — FlaxMarianModel (Marian model)
- mbart — FlaxMBartModel (mBART model)
- mistral — FlaxMistralModel (Mistral model)
- mt5 — FlaxMT5Model (MT5 model)
- opt — FlaxOPTModel (OPT model)
- pegasus — FlaxPegasusModel (Pegasus model)
- regnet — FlaxRegNetModel (RegNet model)
- resnet — FlaxResNetModel (ResNet model)
- roberta — FlaxRobertaModel (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerModel (RoFormer model)
- t5 — FlaxT5Model (T5 model)
- vision-text-dual-encoder — FlaxVisionTextDualEncoderModel (VisionTextDualEncoder model)
- vit — FlaxViTModel (ViT model)
- wav2vec2 — FlaxWav2Vec2Model (Wav2Vec2 model)
- whisper — FlaxWhisperModel (Whisper model)
- xglm — FlaxXGLMModel (XGLM model)
- xlm-roberta — FlaxXLMRobertaModel (XLM-RoBERTa model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModel.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModel.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
Generic pretraining classes
The following auto classes can be used to instantiate a model with a pretraining head.
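As a quick offline check of which pretraining head gets attached, this sketch builds two tiny, randomly initialized models with AutoModelForPreTraining.from_config. It assumes PyTorch is installed, and the configuration sizes are arbitrary; the expected classes come from the configuration-class mapping documented for AutoModelForPreTraining (BertConfig to BertForPreTraining, GPT2Config to GPT2LMHeadModel).

```python
from transformers import AutoModelForPreTraining, BertConfig, GPT2Config

# Tiny configurations with arbitrary sizes, so no download is needed.
bert_config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=37,
)
gpt2_config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=1, n_head=2)

# The head that gets attached depends on the configuration class.
bert_model = AutoModelForPreTraining.from_config(bert_config)  # masked LM + next-sentence heads
gpt2_model = AutoModelForPreTraining.from_config(gpt2_config)  # causal LM head

print(type(bert_model).__name__, type(gpt2_model).__name__)
```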
AutoModelForPreTraining
This is a generic model class that will be instantiated as one of the model classes of the library (with a pretraining head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: AlbertForPreTraining (ALBERT model)
  - BartConfig configuration class: BartForConditionalGeneration (BART model)
  - BertConfig configuration class: BertForPreTraining (BERT model)
  - BigBirdConfig configuration class: BigBirdForPreTraining (BigBird model)
  - BloomConfig configuration class: BloomForCausalLM (BLOOM model)
  - CTRLConfig configuration class: CTRLLMHeadModel (CTRL model)
  - CamembertConfig configuration class: CamembertForMaskedLM (CamemBERT model)
  - ColPaliConfig configuration class: ColPaliForRetrieval (ColPali model)
  - Data2VecTextConfig configuration class: Data2VecTextForMaskedLM (Data2VecText model)
  - DebertaConfig configuration class: DebertaForMaskedLM (DeBERTa model)
  - DebertaV2Config configuration class: DebertaV2ForMaskedLM (DeBERTa-v2 model)
  - DistilBertConfig configuration class: DistilBertForMaskedLM (DistilBERT model)
  - ElectraConfig configuration class: ElectraForPreTraining (ELECTRA model)
  - ErnieConfig configuration class: ErnieForPreTraining (ERNIE model)
  - FNetConfig configuration class: FNetForPreTraining (FNet model)
  - FSMTConfig configuration class: FSMTForConditionalGeneration (FairSeq Machine-Translation model)
  - FalconMambaConfig configuration class: FalconMambaForCausalLM (FalconMamba model)
  - FlaubertConfig configuration class: FlaubertWithLMHeadModel (FlauBERT model)
  - FlavaConfig configuration class: FlavaForPreTraining (FLAVA model)
  - FunnelConfig configuration class: FunnelForPreTraining (Funnel Transformer model)
  - GPT2Config configuration class: GPT2LMHeadModel (OpenAI GPT-2 model)
  - GPTBigCodeConfig configuration class: GPTBigCodeForCausalLM (GPTBigCode model)
  - GPTSanJapaneseConfig configuration class: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
  - HieraConfig configuration class: HieraForPreTraining (Hiera model)
  - IBertConfig configuration class: IBertForMaskedLM (I-BERT model)
  - Idefics2Config configuration class: Idefics2ForConditionalGeneration (Idefics2 model)
  - Idefics3Config configuration class: Idefics3ForConditionalGeneration (Idefics3 model)
  - IdeficsConfig configuration class: IdeficsForVisionText2Text (IDEFICS model)
  - LayoutLMConfig configuration class: LayoutLMForMaskedLM (LayoutLM model)
  - LlavaConfig configuration class: LlavaForConditionalGeneration (LLaVa model)
  - LlavaNextConfig configuration class: LlavaNextForConditionalGeneration (LLaVA-NeXT model)
  - LlavaNextVideoConfig configuration class: LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video model)
  - LlavaOnevisionConfig configuration class: LlavaOnevisionForConditionalGeneration (LLaVA-Onevision model)
  - LongformerConfig configuration class: LongformerForMaskedLM (Longformer model)
  - LukeConfig configuration class: LukeForMaskedLM (LUKE model)
  - LxmertConfig configuration class: LxmertForPreTraining (LXMERT model)
  - MPNetConfig configuration class: MPNetForMaskedLM (MPNet model)
  - Mamba2Config configuration class: Mamba2ForCausalLM (mamba2 model)
  - MambaConfig configuration class: MambaForCausalLM (Mamba model)
  - MegaConfig configuration class: MegaForMaskedLM (MEGA model)
  - MegatronBertConfig configuration class: MegatronBertForPreTraining (Megatron-BERT model)
  - MllamaConfig configuration class: MllamaForConditionalGeneration (Mllama model)
  - MobileBertConfig configuration class: MobileBertForPreTraining (MobileBERT model)
  - MptConfig configuration class: MptForCausalLM (MPT model)
  - MraConfig configuration class: MraForMaskedLM (MRA model)
  - MvpConfig configuration class: MvpForConditionalGeneration (MVP model)
  - NezhaConfig configuration class: NezhaForPreTraining (Nezha model)
  - NllbMoeConfig configuration class: NllbMoeForConditionalGeneration (NLLB-MOE model)
  - OpenAIGPTConfig configuration class: OpenAIGPTLMHeadModel (OpenAI GPT model)
  - PaliGemmaConfig configuration class: PaliGemmaForConditionalGeneration (PaliGemma model)
  - Qwen2AudioConfig configuration class: Qwen2AudioForConditionalGeneration (Qwen2Audio model)
  - RetriBertConfig configuration class: RetriBertModel (RetriBERT model)
  - RoCBertConfig configuration class: RoCBertForPreTraining (RoCBert model)
  - RobertaConfig configuration class: RobertaForMaskedLM (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
  - RwkvConfig configuration class: RwkvForCausalLM (RWKV model)
  - SplinterConfig configuration class: SplinterForPreTraining (Splinter model)
  - SqueezeBertConfig configuration class: SqueezeBertForMaskedLM (SqueezeBERT model)
  - SwitchTransformersConfig configuration class: SwitchTransformersForConditionalGeneration (SwitchTransformers model)
  - T5Config configuration class: T5ForConditionalGeneration (T5 model)
  - TapasConfig configuration class: TapasForMaskedLM (TAPAS model)
  - TransfoXLConfig configuration class: TransfoXLLMHeadModel (Transformer-XL model)
  - TvltConfig configuration class: TvltForPreTraining (TVLT model)
  - UniSpeechConfig configuration class: UniSpeechForPreTraining (UniSpeech model)
  - UniSpeechSatConfig configuration class: UniSpeechSatForPreTraining (UniSpeechSat model)
  - ViTMAEConfig configuration class: ViTMAEForPreTraining (ViTMAE model)
  - VideoLlavaConfig configuration class: VideoLlavaForConditionalGeneration (VideoLlava model)
  - VideoMAEConfig configuration class: VideoMAEForPreTraining (VideoMAE model)
  - VipLlavaConfig configuration class: VipLlavaForConditionalGeneration (VipLlava model)
  - VisualBertConfig configuration class: VisualBertForPreTraining (VisualBERT model)
  - Wav2Vec2Config configuration class: Wav2Vec2ForPreTraining (Wav2Vec2 model)
  - Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForPreTraining (Wav2Vec2-Conformer model)
  - XLMConfig configuration class: XLMWithLMHeadModel (XLM model)
  - XLMRobertaConfig configuration class: XLMRobertaForMaskedLM (XLM-RoBERTa model)
  - XLMRobertaXLConfig configuration class: XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL model)
  - XLNetConfig configuration class: XLNetLMHeadModel (XLNet model)
  - XmodConfig configuration class: XmodForMaskedLM (X-MOD model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
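The default rule for attn_implementation can be sketched as a small selection function. This is a minimal sketch, not the library's code: `pick_attn_implementation` and the `sdpa_supported` flag (standing in for "if available") are hypothetical, the torch version is passed in as a plain release string rather than read from an installed torch, and availability checks for an explicitly requested backend are omitted.

```python
from typing import Optional

# Hypothetical helper mirroring the documented default rule; not part of transformers.
VALID_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2"}


def _version_tuple(version: str) -> tuple:
    # "2.1.1" -> (2, 1, 1); a local suffix such as "+cu121" is dropped.
    # Only plain release versions are handled in this sketch.
    core = version.split("+")[0]
    return tuple(int(part) for part in core.split(".")[:3])


def pick_attn_implementation(
    requested: Optional[str], torch_version: str, sdpa_supported: bool = True
) -> str:
    """Return the attention implementation following the documented default."""
    if requested is not None:
        if requested not in VALID_IMPLEMENTATIONS:
            raise ValueError(f"Unknown attn_implementation: {requested!r}")
        return requested  # an explicit choice wins (availability checks omitted)
    # No explicit choice: SDPA when available and torch >= 2.1.1, otherwise eager
    if sdpa_supported and _version_tuple(torch_version) >= (2, 1, 1):
        return "sdpa"
    return "eager"


print(pick_attn_implementation(None, "2.3.0"))     # sdpa
print(pick_attn_implementation(None, "2.0.1"))     # eager
print(pick_attn_implementation("eager", "2.3.0"))  # eager
```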
Instantiates one of the model classes of the library (with a pretraining head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
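Model selection in from_config is driven by the type of the configuration object, and the result is a freshly initialized model. Here is a minimal pure-Python sketch of that lookup. It assumes a plain dict keyed on configuration classes, and all Toy* classes are hypothetical stand-ins (the library's real mapping is lazily loaded and more elaborate). With the real library, the corresponding call is AutoModelForPreTraining.from_config(config), which likewise builds the architecture without loading pretrained weights.

```python
import random
from dataclasses import dataclass


# Hypothetical stand-ins for configuration and model classes; not the transformers classes.
@dataclass
class ToyBertConfig:
    hidden_size: int = 4


@dataclass
class ToyT5Config:
    d_model: int = 4


class ToyBertForPreTraining:
    def __init__(self, config):
        self.config = config
        # Weights are freshly (randomly) initialized: nothing is read from a checkpoint
        self.weights = [random.random() for _ in range(config.hidden_size)]


class ToyT5ForConditionalGeneration:
    def __init__(self, config):
        self.config = config
        self.weights = [random.random() for _ in range(config.d_model)]


class ToyAutoModelForPreTraining:
    # Configuration class -> model class, in the spirit of the list above
    _model_mapping = {
        ToyBertConfig: ToyBertForPreTraining,
        ToyT5Config: ToyT5ForConditionalGeneration,
    }

    @classmethod
    def from_config(cls, config):
        model_class = cls._model_mapping.get(type(config))
        if model_class is None:
            raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
        return model_class(config)


model = ToyAutoModelForPreTraining.from_config(ToyT5Config())
print(type(model).__name__)  # ToyT5ForConditionalGeneration
```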
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
- A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
- A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
- A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, consider whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
- If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
- If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
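The two kwargs behaviours above can be summarised in a short pure-Python sketch. It assumes a dataclass-style configuration. `ToyConfig` and `split_kwargs` are hypothetical names, and the default `ToyConfig()` stands in for a configuration loaded with from_pretrained(). This illustrates the documented rule; it is not the library's implementation.

```python
from dataclasses import dataclass, fields


# Hypothetical configuration used only for this illustration
@dataclass
class ToyConfig:
    output_attentions: bool = False
    hidden_dropout_prob: float = 0.1


def split_kwargs(kwargs: dict, config=None):
    """Return (config, kwargs_for_model_init) following the documented rules."""
    if config is not None:
        # Config supplied explicitly: every kwarg goes straight to the model's __init__
        return config, dict(kwargs)
    # No config supplied: keys matching config attributes override the loaded config
    config = ToyConfig()  # stands in for the configuration loaded from the checkpoint
    config_attrs = {f.name for f in fields(config)}
    model_kwargs = {}
    for key, value in kwargs.items():
        if key in config_attrs:
            setattr(config, key, value)
        else:
            model_kwargs[key] = value  # leftover keys are forwarded to __init__
    return config, model_kwargs


cfg, rest = split_kwargs({"output_attentions": True, "custom_flag": 1})
print(cfg.output_attentions, rest)  # True {'custom_flag': 1}
```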
Instantiate one of the model classes of the library (with a pretraining head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertForPreTraining (ALBERT model)
- bart — BartForConditionalGeneration (BART model)
- bert — BertForPreTraining (BERT model)
- big_bird — BigBirdForPreTraining (BigBird model)
- bloom — BloomForCausalLM (BLOOM model)
- camembert — CamembertForMaskedLM (CamemBERT model)
- colpali — ColPaliForRetrieval (ColPali model)
- ctrl — CTRLLMHeadModel (CTRL model)
- data2vec-text — Data2VecTextForMaskedLM (Data2VecText model)
- deberta — DebertaForMaskedLM (DeBERTa model)
- deberta-v2 — DebertaV2ForMaskedLM (DeBERTa-v2 model)
- distilbert — DistilBertForMaskedLM (DistilBERT model)
- electra — ElectraForPreTraining (ELECTRA model)
- ernie — ErnieForPreTraining (ERNIE model)
- falcon_mamba — FalconMambaForCausalLM (FalconMamba model)
- flaubert — FlaubertWithLMHeadModel (FlauBERT model)
- flava — FlavaForPreTraining (FLAVA model)
- fnet — FNetForPreTraining (FNet model)
- fsmt — FSMTForConditionalGeneration (FairSeq Machine-Translation model)
- funnel — FunnelForPreTraining (Funnel Transformer model)
- gpt-sw3 — GPT2LMHeadModel (GPT-Sw3 model)
- gpt2 — GPT2LMHeadModel (OpenAI GPT-2 model)
- gpt_bigcode — GPTBigCodeForCausalLM (GPTBigCode model)
- gptsan-japanese — GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- hiera — HieraForPreTraining (Hiera model)
- ibert — IBertForMaskedLM (I-BERT model)
- idefics — IdeficsForVisionText2Text (IDEFICS model)
- idefics2 — Idefics2ForConditionalGeneration (Idefics2 model)
- idefics3 — Idefics3ForConditionalGeneration (Idefics3 model)
- layoutlm — LayoutLMForMaskedLM (LayoutLM model)
- llava — LlavaForConditionalGeneration (LLaVa model)
- llava_next — LlavaNextForConditionalGeneration (LLaVA-NeXT model)
- llava_next_video — LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video model)
- llava_onevision — LlavaOnevisionForConditionalGeneration (LLaVA-Onevision model)
- longformer — LongformerForMaskedLM (Longformer model)
- luke — LukeForMaskedLM (LUKE model)
- lxmert — LxmertForPreTraining (LXMERT model)
- mamba — MambaForCausalLM (Mamba model)
- mamba2 — Mamba2ForCausalLM (mamba2 model)
- mega — MegaForMaskedLM (MEGA model)
- megatron-bert — MegatronBertForPreTraining (Megatron-BERT model)
- mllama — MllamaForConditionalGeneration (Mllama model)
- mobilebert — MobileBertForPreTraining (MobileBERT model)
- mpnet — MPNetForMaskedLM (MPNet model)
- mpt — MptForCausalLM (MPT model)
- mra — MraForMaskedLM (MRA model)
- mvp — MvpForConditionalGeneration (MVP model)
- nezha — NezhaForPreTraining (Nezha model)
- nllb-moe — NllbMoeForConditionalGeneration (NLLB-MOE model)
- openai-gpt — OpenAIGPTLMHeadModel (OpenAI GPT model)
- paligemma — PaliGemmaForConditionalGeneration (PaliGemma model)
- qwen2_audio — Qwen2AudioForConditionalGeneration (Qwen2Audio model)
- retribert — RetriBertModel (RetriBERT model)
- roberta — RobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForPreTraining (RoCBert model)
- rwkv — RwkvForCausalLM (RWKV model)
- splinter — SplinterForPreTraining (Splinter model)
- squeezebert — SqueezeBertForMaskedLM (SqueezeBERT model)
- switch_transformers — SwitchTransformersForConditionalGeneration (SwitchTransformers model)
- t5 — T5ForConditionalGeneration (T5 model)
- tapas — TapasForMaskedLM (TAPAS model)
- transfo-xl — TransfoXLLMHeadModel (Transformer-XL model)
- tvlt — TvltForPreTraining (TVLT model)
- unispeech — UniSpeechForPreTraining (UniSpeech model)
- unispeech-sat — UniSpeechSatForPreTraining (UniSpeechSat model)
- video_llava — VideoLlavaForConditionalGeneration (VideoLlava model)
- videomae — VideoMAEForPreTraining (VideoMAE model)
- vipllava — VipLlavaForConditionalGeneration (VipLlava model)
- visual_bert — VisualBertForPreTraining (VisualBERT model)
- vit_mae — ViTMAEForPreTraining (ViTMAE model)
- wav2vec2 — Wav2Vec2ForPreTraining (Wav2Vec2 model)
- wav2vec2-conformer — Wav2Vec2ConformerForPreTraining (Wav2Vec2-Conformer model)
- xlm — XLMWithLMHeadModel (XLM model)
- xlm-roberta — XLMRobertaForMaskedLM (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL model)
- xlnet — XLNetLMHeadModel (XLNet model)
- xmod — XmodForMaskedLM (X-MOD model)
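The selection rule just described can be sketched as follows: use model_type when the configuration provides it, otherwise pattern-match on the name. This is a hedged illustration. The dictionary is a hand-copied excerpt of the list above, and `resolve_model_class` is a hypothetical name. Trying longer patterns first is an assumption about how ambiguous names should be resolved. It matters here because "bert" is a substring of "roberta", and "xlm" is a prefix of "xlm-roberta-xl".

```python
from typing import Optional

# A small excerpt of the model_type -> class-name mapping listed above
MODEL_FOR_PRETRAINING = {
    "bert": "BertForPreTraining",
    "roberta": "RobertaForMaskedLM",
    "xlm": "XLMWithLMHeadModel",
    "xlm-roberta": "XLMRobertaForMaskedLM",
    "xlm-roberta-xl": "XLMRobertaXLForMaskedLM",
}


def resolve_model_class(model_type: Optional[str], name_or_path: str) -> str:
    if model_type is not None:
        # Preferred path: the configuration states its model_type explicitly
        return MODEL_FOR_PRETRAINING[model_type]
    # Fallback: substring match on the name, trying longer keys first so that
    # "xlm-roberta-xl" is not captured by the shorter "xlm", "roberta" or "bert" keys
    for pattern in sorted(MODEL_FOR_PRETRAINING, key=len, reverse=True):
        if pattern in name_or_path:
            return MODEL_FOR_PRETRAINING[pattern]
    raise ValueError(f"Could not infer a model type from {name_or_path!r}")


print(resolve_model_class(None, "FacebookAI/roberta-base"))  # RobertaForMaskedLM
```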
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
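To see why the mode matters, here is a minimal pure-Python sketch of inverted dropout with a train/eval switch, the behaviour that model.eval() and model.train() toggle. It assumes the standard inverted-dropout formulation. `ToyDropout` is a hypothetical class that mimics the idea; it is not torch.nn.Dropout.

```python
import random


class ToyDropout:
    """Inverted dropout with a train/eval switch, illustrating the idea behind nn.Dropout."""

    def __init__(self, p: float = 0.5, seed: int = 0):
        self.p = p
        self.training = True
        self._rng = random.Random(seed)

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, xs):
        if not self.training or self.p == 0.0:
            return list(xs)  # evaluation mode: dropout is a no-op
        keep = 1.0 - self.p
        # Training mode: zero each value with probability p, rescale survivors by 1/keep
        return [x / keep if self._rng.random() < keep else 0.0 for x in xs]


drop = ToyDropout(p=0.5).eval()
print(drop([1.0, 2.0, 3.0]))  # [1.0, 2.0, 3.0]
```

Because a freshly loaded model starts in evaluation mode, forgetting model.train() before fine-tuning silently disables this regularisation.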
Examples:
>>> from transformers import AutoConfig, AutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForPreTraining.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForPreTraining
This is a generic model class that will be instantiated as one of the model classes of the library (with a pretraining head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
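The "cannot be instantiated directly" behaviour is a factory pattern: the auto class only dispatches and never becomes the model itself. Here is a minimal sketch of that pattern. All class names are hypothetical, and a SimpleNamespace stands in for a configuration object. Using EnvironmentError (an alias of OSError in Python 3) is an assumption; the library's exact exception and message may differ.

```python
from types import SimpleNamespace


class ToyBertForPreTraining:
    # Hypothetical concrete model class standing in for a real architecture
    def __init__(self, config):
        self.config = config


class ToyAutoModel:
    """Factory-only class: build models through the classmethod, never via __init__."""

    # Hypothetical model_type -> concrete model class mapping
    _model_mapping = {"bert": ToyBertForPreTraining}

    def __init__(self, *args, **kwargs):
        # Blocking the constructor forces callers through the dispatching classmethods
        name = type(self).__name__
        raise EnvironmentError(
            f"{name} is designed to be instantiated using {name}.from_pretrained(...) "
            f"or {name}.from_config(config)."
        )

    @classmethod
    def from_config(cls, config):
        # The returned object is an instance of the concrete class, not of the auto class
        return cls._model_mapping[config.model_type](config)


model = ToyAutoModel.from_config(SimpleNamespace(model_type="bert", hidden_size=4))
print(type(model).__name__)  # ToyBertForPreTraining
```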
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertForPreTraining (ALBERT model)
- BartConfig configuration class: TFBartForConditionalGeneration (BART model)
- BertConfig configuration class: TFBertForPreTraining (BERT model)
- CTRLConfig configuration class: TFCTRLLMHeadModel (CTRL model)
- CamembertConfig configuration class: TFCamembertForMaskedLM (CamemBERT model)
- DistilBertConfig configuration class: TFDistilBertForMaskedLM (DistilBERT model)
- ElectraConfig configuration class: TFElectraForPreTraining (ELECTRA model)
- FlaubertConfig configuration class: TFFlaubertWithLMHeadModel (FlauBERT model)
- FunnelConfig configuration class: TFFunnelForPreTraining (Funnel Transformer model)
- GPT2Config configuration class: TFGPT2LMHeadModel (OpenAI GPT-2 model)
- IdeficsConfig configuration class: TFIdeficsForVisionText2Text (IDEFICS model)
- LayoutLMConfig configuration class: TFLayoutLMForMaskedLM (LayoutLM model)
- LxmertConfig configuration class: TFLxmertForPreTraining (LXMERT model)
- MPNetConfig configuration class: TFMPNetForMaskedLM (MPNet model)
- MobileBertConfig configuration class: TFMobileBertForPreTraining (MobileBERT model)
- OpenAIGPTConfig configuration class: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
- RobertaConfig configuration class: TFRobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- T5Config configuration class: TFT5ForConditionalGeneration (T5 model)
- TapasConfig configuration class: TFTapasForMaskedLM (TAPAS model)
- TransfoXLConfig configuration class: TFTransfoXLLMHeadModel (Transformer-XL model)
- ViTMAEConfig configuration class: TFViTMAEForPreTraining (ViTMAE model)
- XLMConfig configuration class: TFXLMWithLMHeadModel (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForMaskedLM (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetLMHeadModel (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a pretraining head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
- A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
- A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
- A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
- If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
- If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
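The proxies mapping can be keyed by protocol ("http") or by endpoint ("http://hostname"). The sketch below shows how such a mapping is typically resolved. It is modelled on the lookup order used by the requests library's select_proxy helper, assuming downloads go through a requests-based HTTP client. The function is a standalone re-implementation for illustration, not an import from any library.

```python
from typing import Dict, Optional
from urllib.parse import urlparse


def select_proxy(url: str, proxies: Dict[str, str]) -> Optional[str]:
    """Pick the proxy for `url`, preferring endpoint-specific keys over protocol keys."""
    parts = urlparse(url)
    if parts.hostname is None:
        return proxies.get(parts.scheme, proxies.get("all"))
    candidates = [
        f"{parts.scheme}://{parts.hostname}",  # endpoint key, e.g. 'http://hostname'
        parts.scheme,                          # protocol key, e.g. 'http'
        f"all://{parts.hostname}",
        "all",
    ]
    for key in candidates:
        if key in proxies:
            return proxies[key]
    return None


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(select_proxy("http://hostname/model.bin", proxies))  # foo.bar:4012
print(select_proxy("http://other/model.bin", proxies))     # foo.bar:3128
```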
Instantiate one of the model classes of the library (with a pretraining head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — TFAlbertForPreTraining (ALBERT model)
- bart — TFBartForConditionalGeneration (BART model)
- bert — TFBertForPreTraining (BERT model)
- camembert — TFCamembertForMaskedLM (CamemBERT model)
- ctrl — TFCTRLLMHeadModel (CTRL model)
- distilbert — TFDistilBertForMaskedLM (DistilBERT model)
- electra — TFElectraForPreTraining (ELECTRA model)
- flaubert — TFFlaubertWithLMHeadModel (FlauBERT model)
- funnel — TFFunnelForPreTraining (Funnel Transformer model)
- gpt-sw3 — TFGPT2LMHeadModel (GPT-Sw3 model)
- gpt2 — TFGPT2LMHeadModel (OpenAI GPT-2 model)
- idefics — TFIdeficsForVisionText2Text (IDEFICS model)
- layoutlm — TFLayoutLMForMaskedLM (LayoutLM model)
- lxmert — TFLxmertForPreTraining (LXMERT model)
- mobilebert — TFMobileBertForPreTraining (MobileBERT model)
- mpnet — TFMPNetForMaskedLM (MPNet model)
- openai-gpt — TFOpenAIGPTLMHeadModel (OpenAI GPT model)
- roberta — TFRobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- t5 — TFT5ForConditionalGeneration (T5 model)
- tapas — TFTapasForMaskedLM (TAPAS model)
- transfo-xl — TFTransfoXLLMHeadModel (Transformer-XL model)
- vit_mae — TFViTMAEForPreTraining (ViTMAE model)
- xlm — TFXLMWithLMHeadModel (XLM model)
- xlm-roberta — TFXLMRobertaForMaskedLM (XLM-RoBERTa model)
- xlnet — TFXLNetLMHeadModel (XLNet model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForPreTraining.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForPreTraining
This is a generic model class that will be instantiated as one of the model classes of the library (with a pretraining head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertForPreTraining (ALBERT model)
- BartConfig configuration class: FlaxBartForConditionalGeneration (BART model)
- BertConfig configuration class: FlaxBertForPreTraining (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForPreTraining (BigBird model)
- ElectraConfig configuration class: FlaxElectraForPreTraining (ELECTRA model)
- LongT5Config configuration class: FlaxLongT5ForConditionalGeneration (LongT5 model)
- MBartConfig configuration class: FlaxMBartForConditionalGeneration (mBART model)
- MT5Config configuration class: FlaxMT5ForConditionalGeneration (MT5 model)
- RoFormerConfig configuration class: FlaxRoFormerForMaskedLM (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- T5Config configuration class: FlaxT5ForConditionalGeneration (T5 model)
- Wav2Vec2Config configuration class: FlaxWav2Vec2ForPreTraining (Wav2Vec2 model)
- WhisperConfig configuration class: FlaxWhisperForConditionalGeneration (Whisper model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForMaskedLM (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a pretraining head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
- A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
- A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
- A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
- If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
- If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
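What output_loading_info reports can be pictured as a set comparison. On one side are the parameter names the model defines; on the other, the names stored in the checkpoint. This is a hedged sketch. `loading_info` and the example key names are hypothetical, and the dictionary keys are chosen to mirror the description above (an assumption). A real loader also handles name prefixes, tied weights and shape mismatches, which are ignored here. In the PyTorch classes, passing output_loading_info=True makes from_pretrained() return the model together with such a dictionary.

```python
from typing import Dict, Iterable, List


def loading_info(expected_keys: Iterable[str], checkpoint_keys: Iterable[str]) -> Dict[str, List[str]]:
    """Compare the parameter names a model expects with those found in a checkpoint."""
    expected, found = set(expected_keys), set(checkpoint_keys)
    return {
        # In the model but absent from the checkpoint: these stay randomly initialized
        "missing_keys": sorted(expected - found),
        # In the checkpoint but not used by the model: these are ignored
        "unexpected_keys": sorted(found - expected),
        "error_msgs": [],  # a real loader would also record loading errors here
    }


model_keys = ["bert.embeddings.weight", "cls.predictions.bias"]
ckpt_keys = ["bert.embeddings.weight", "bert.pooler.dense.weight"]
info = loading_info(model_keys, ckpt_keys)
print(info["missing_keys"])     # ['cls.predictions.bias']
print(info["unexpected_keys"])  # ['bert.pooler.dense.weight']
```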
Instantiate one of the model classes of the library (with a pretraining head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertForPreTraining (ALBERT model)
- bart — FlaxBartForConditionalGeneration (BART model)
- bert — FlaxBertForPreTraining (BERT model)
- big_bird — FlaxBigBirdForPreTraining (BigBird model)
- electra — FlaxElectraForPreTraining (ELECTRA model)
- longt5 — FlaxLongT5ForConditionalGeneration (LongT5 model)
- mbart — FlaxMBartForConditionalGeneration (mBART model)
- mt5 — FlaxMT5ForConditionalGeneration (MT5 model)
- roberta — FlaxRobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForMaskedLM (RoFormer model)
- t5 — FlaxT5ForConditionalGeneration (T5 model)
- wav2vec2 — FlaxWav2Vec2ForPreTraining (Wav2Vec2 model)
- whisper — FlaxWhisperForConditionalGeneration (Whisper model)
- xlm-roberta — FlaxXLMRobertaForMaskedLM (XLM-RoBERTa model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForPreTraining.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
Natural Language Processing
The following auto classes are available for the natural language processing tasks listed below.
AutoModelForCausalLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a causal language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
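"Causal" language modeling means that each position may only use earlier tokens, and the head is trained to predict the next token. Here is a minimal pure-Python sketch of those two ideas: the lower-triangular attention mask, and the pairing of each prefix with its next token. The function names are hypothetical. This illustrates the concept only; it is not how the library builds its masks or labels.

```python
from typing import List, Tuple


def causal_mask(seq_len: int) -> List[List[bool]]:
    # Position i may attend to positions 0..i only (lower-triangular mask)
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def next_token_targets(token_ids: List[int]) -> List[Tuple[List[int], int]]:
    # Each prefix is paired with the token that follows it, which a causal LM head predicts
    return [(token_ids[: i + 1], token_ids[i + 1]) for i in range(len(token_ids) - 1)]


for row in causal_mask(3):
    print(row)
# [True, False, False]
# [True, True, False]
# [True, True, True]
print(next_token_targets([5, 7, 9]))  # [([5], 7), ([5, 7], 9)]
```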
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AriaTextConfig configuration class: AriaTextForCausalLM (AriaText model)
  - BambaConfig configuration class: BambaForCausalLM (Bamba model)
  - BartConfig configuration class: BartForCausalLM (BART model)
  - BertConfig configuration class: BertLMHeadModel (BERT model)
  - BertGenerationConfig configuration class: BertGenerationDecoder (Bert Generation model)
  - BigBirdConfig configuration class: BigBirdForCausalLM (BigBird model)
  - BigBirdPegasusConfig configuration class: BigBirdPegasusForCausalLM (BigBird-Pegasus model)
  - BioGptConfig configuration class: BioGptForCausalLM (BioGpt model)
  - BlenderbotConfig configuration class: BlenderbotForCausalLM (Blenderbot model)
  - BlenderbotSmallConfig configuration class: BlenderbotSmallForCausalLM (BlenderbotSmall model)
  - BloomConfig configuration class: BloomForCausalLM (BLOOM model)
  - CTRLConfig configuration class: CTRLLMHeadModel (CTRL model)
  - CamembertConfig configuration class: CamembertForCausalLM (CamemBERT model)
  - CodeGenConfig configuration class: CodeGenForCausalLM (CodeGen model)
  - Cohere2Config configuration class: Cohere2ForCausalLM (Cohere2 model)
  - CohereConfig configuration class: CohereForCausalLM (Cohere model)
  - CpmAntConfig configuration class: CpmAntForCausalLM (CPM-Ant model)
  - Data2VecTextConfig configuration class: Data2VecTextForCausalLM (Data2VecText model)
  - DbrxConfig configuration class: DbrxForCausalLM (DBRX model)
  - DiffLlamaConfig configuration class: DiffLlamaForCausalLM (DiffLlama model)
  - ElectraConfig configuration class: ElectraForCausalLM (ELECTRA model)
  - Emu3Config configuration class: Emu3ForCausalLM (Emu3 model)
  - ErnieConfig configuration class: ErnieForCausalLM (ERNIE model)
  - FalconConfig configuration class: FalconForCausalLM (Falcon model)
  - FalconMambaConfig configuration class: FalconMambaForCausalLM (FalconMamba model)
  - FuyuConfig configuration class: FuyuForCausalLM (Fuyu model)
  - GPT2Config configuration class: GPT2LMHeadModel (OpenAI GPT-2 model)
  - GPTBigCodeConfig configuration class: GPTBigCodeForCausalLM (GPTBigCode model)
  - GPTJConfig configuration class: GPTJForCausalLM (GPT-J model)
  - GPTNeoConfig configuration class: GPTNeoForCausalLM (GPT Neo model)
  - GPTNeoXConfig configuration class: GPTNeoXForCausalLM (GPT NeoX model)
  - GPTNeoXJapaneseConfig configuration class: GPTNeoXJapaneseForCausalLM (GPT NeoX Japanese model)
  - Gemma2Config configuration class: Gemma2ForCausalLM (Gemma2 model)
  - GemmaConfig configuration class: GemmaForCausalLM (Gemma model)
  - GitConfig configuration class: GitForCausalLM (GIT model)
  - GlmConfig configuration class: GlmForCausalLM (GLM model)
  - GraniteConfig configuration class: GraniteForCausalLM (Granite model)
  - GraniteMoeConfig configuration class: GraniteMoeForCausalLM (GraniteMoe model)
  - HeliumConfig configuration class: HeliumForCausalLM (Helium model)
  - JambaConfig configuration class: JambaForCausalLM (Jamba model)
  - JetMoeConfig configuration class: JetMoeForCausalLM (JetMoe model)
  - LlamaConfig configuration class: LlamaForCausalLM (LLaMA model)
  - MBartConfig configuration class: MBartForCausalLM (mBART model)
  - Mamba2Config configuration class: Mamba2ForCausalLM (mamba2 model)
  - MambaConfig configuration class: MambaForCausalLM (Mamba model)
  - MarianConfig configuration class: MarianForCausalLM (Marian model)
  - MegaConfig configuration class: MegaForCausalLM (MEGA model)
  - MegatronBertConfig configuration class: MegatronBertForCausalLM (Megatron-BERT model)
  - MistralConfig configuration class: MistralForCausalLM (Mistral model)
  - MixtralConfig configuration class: MixtralForCausalLM (Mixtral model)
  - MllamaConfig configuration class: MllamaForCausalLM (Mllama model)
  - MoshiConfig configuration class: MoshiForCausalLM (Moshi model)
  - MptConfig configuration class: MptForCausalLM (MPT model)
  - MusicgenConfig configuration class: MusicgenForCausalLM (MusicGen model)
  - MusicgenMelodyConfig configuration class: MusicgenMelodyForCausalLM (MusicGen Melody model)
  - MvpConfig configuration class: MvpForCausalLM (MVP model)
  - NemotronConfig configuration class: NemotronForCausalLM (Nemotron model)
  - OPTConfig configuration class: OPTForCausalLM (OPT model)
  - Olmo2Config configuration class: Olmo2ForCausalLM (OLMo2 model)
  - OlmoConfig configuration class: OlmoForCausalLM (OLMo model)
  - OlmoeConfig configuration class: OlmoeForCausalLM (OLMoE model)
  - OpenAIGPTConfig configuration class: OpenAIGPTLMHeadModel (OpenAI GPT model)
  - OpenLlamaConfig configuration class: OpenLlamaForCausalLM (OpenLlama model)
  - PLBartConfig configuration class: PLBartForCausalLM (PLBart model)
  - PegasusConfig configuration class: PegasusForCausalLM (Pegasus model)
  - PersimmonConfig configuration class: PersimmonForCausalLM (Persimmon model)
  - Phi3Config configuration class: Phi3ForCausalLM (Phi3 model)
  - PhiConfig configuration class: PhiForCausalLM (Phi model)
  - PhimoeConfig configuration class: PhimoeForCausalLM (Phimoe model)
  - ProphetNetConfig configuration class: ProphetNetForCausalLM (ProphetNet model)
  - QDQBertConfig configuration class: QDQBertLMHeadModel (QDQBert model)
  - Qwen2Config configuration class: Qwen2ForCausalLM (Qwen2 model)
  - Qwen2MoeConfig configuration class: Qwen2MoeForCausalLM (Qwen2MoE model)
  - RecurrentGemmaConfig configuration class: RecurrentGemmaForCausalLM (RecurrentGemma model)
  - ReformerConfig configuration class: ReformerModelWithLMHead (Reformer model)
  - RemBertConfig configuration class: RemBertForCausalLM (RemBERT model)
  - RoCBertConfig configuration class: RoCBertForCausalLM (RoCBert model)
  - RoFormerConfig configuration class: RoFormerForCausalLM (RoFormer model)
  - RobertaConfig configuration class: RobertaForCausalLM (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
  - RwkvConfig configuration class: RwkvForCausalLM (RWKV model)
  - Speech2Text2Config configuration class: Speech2Text2ForCausalLM (Speech2Text2 model)
  - StableLmConfig configuration class: StableLmForCausalLM (StableLm model)
  - Starcoder2Config configuration class: Starcoder2ForCausalLM (Starcoder2 model)
  - TrOCRConfig configuration class: TrOCRForCausalLM (TrOCR model)
  - TransfoXLConfig configuration class: TransfoXLLMHeadModel (Transformer-XL model)
  - WhisperConfig configuration class: WhisperForCausalLM (Whisper model)
  - XGLMConfig configuration class: XGLMForCausalLM (XGLM model)
  - XLMConfig configuration class: XLMWithLMHeadModel (XLM model)
  - XLMProphetNetConfig configuration class: XLMProphetNetForCausalLM (XLM-ProphetNet model)
  - XLMRobertaConfig configuration class: XLMRobertaForCausalLM (XLM-RoBERTa model)
  - XLMRobertaXLConfig configuration class: XLMRobertaXLForCausalLM (XLM-RoBERTa-XL model)
  - XLNetConfig configuration class: XLNetLMHeadModel (XLNet model)
  - XmodConfig configuration class: XmodForCausalLM (X-MOD model)
  - ZambaConfig configuration class: ZambaForCausalLM (Zamba model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
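The default rule for attn_implementation reads as a small decision procedure. The sketch below restates it in plain Python. choose_attn_implementation and its sdpa_supported flag are hypothetical names, and the real library also checks model- and hardware-specific support (it may reject an explicit request the model cannot honor), which this sketch does not model.

```python
from typing import Optional, Tuple

# The three values listed in the parameter description above.
VALID_ATTN_IMPLEMENTATIONS = ("eager", "sdpa", "flash_attention_2")


def choose_attn_implementation(
    requested: Optional[str],
    torch_version: Tuple[int, int, int],
    sdpa_supported: bool,
) -> str:
    """Hypothetical helper restating the documented default rule."""
    if requested is not None:
        if requested not in VALID_ATTN_IMPLEMENTATIONS:
            raise ValueError(f"Unknown attn_implementation: {requested!r}")
        return requested  # an explicit choice always wins in this sketch
    # No explicit choice: prefer SDPA when it is available on torch>=2.1.1.
    if sdpa_supported and torch_version >= (2, 1, 1):
        return "sdpa"
    return "eager"  # otherwise fall back to the manual implementation


print(choose_attn_implementation(None, (2, 1, 1), sdpa_supported=True))  # sdpa
print(choose_attn_implementation(None, (2, 0, 0), sdpa_supported=True))  # eager
```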
Instantiates one of the model classes of the library (with a causal language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
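To make the dispatch concrete, here is a simplified, self-contained sketch of an auto class that picks the model class from the type of the configuration object and refuses direct construction. Every class below is a hypothetical stand-in that only borrows names from the list above; it is not the transformers implementation. Like from_config, it builds the model without reading any weights.

```python
from dataclasses import dataclass


@dataclass
class GPT2Config:
    """Hypothetical stand-in for a configuration class."""
    n_layer: int = 12


@dataclass
class LlamaConfig:
    """Hypothetical stand-in for another configuration class."""
    num_hidden_layers: int = 32


class _SketchModel:
    """Hypothetical base model: it only stores the config."""

    def __init__(self, config, **kwargs):
        self.config = config
        # from_config only builds the architecture; no checkpoint is read.
        self.weights_loaded = False


class GPT2LMHeadModel(_SketchModel):
    pass


class LlamaForCausalLM(_SketchModel):
    pass


# Keyed by configuration *class*, like the "configuration class: model" list above.
CAUSAL_LM_MAPPING = {
    GPT2Config: GPT2LMHeadModel,
    LlamaConfig: LlamaForCausalLM,
}


class AutoModelForCausalLMSketch:
    def __init__(self, *args, **kwargs):
        # Mirrors the documented behavior: direct construction is an error.
        raise EnvironmentError(
            "AutoModelForCausalLMSketch is meant to be used via from_config() "
            "or from_pretrained()."
        )

    @classmethod
    def from_config(cls, config, **kwargs):
        model_cls = CAUSAL_LM_MAPPING.get(type(config))
        if model_cls is None:
            raise ValueError(
                f"Unrecognized configuration class {type(config).__name__}."
            )
        return model_cls(config, **kwargs)


model = AutoModelForCausalLMSketch.from_config(LlamaConfig(num_hidden_layers=2))
print(type(model).__name__, model.weights_loaded)  # LlamaForCausalLM False
```

With the real library, the corresponding call is AutoModelForCausalLM.from_config(config), where config is an instance of one of the configuration classes listed above.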
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() would be a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
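The two kwargs cases above amount to a simple partitioning rule. The sketch below restates it with plain dictionaries. split_from_pretrained_kwargs is a hypothetical helper for illustration only; the real loader works on configuration objects rather than dicts and handles further special keys that are not shown here.

```python
from typing import Any, Dict, Tuple


def split_from_pretrained_kwargs(
    config_attrs: Dict[str, Any],
    kwargs: Dict[str, Any],
    config_provided: bool,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (final config attributes, kwargs forwarded to the model __init__).

    Hypothetical helper restating the documented rule; not library code.
    """
    if config_provided:
        # An explicit config is used as-is; every kwarg goes to __init__.
        return dict(config_attrs), dict(kwargs)
    updated = dict(config_attrs)
    model_kwargs: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in updated:
            updated[key] = value  # overrides an existing config attribute
        else:
            model_kwargs[key] = value  # unknown to the config: sent to __init__
    return updated, model_kwargs


# "my_model_flag" is a hypothetical model-only argument.
print(split_from_pretrained_kwargs(
    {"output_attentions": False, "num_hidden_layers": 12},
    {"output_attentions": True, "my_model_flag": 1},
    config_provided=False,
))
# ({'output_attentions': True, 'num_hidden_layers': 12}, {'my_model_flag': 1})
```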
Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- aria_text — AriaTextForCausalLM (AriaText model)
- bamba — BambaForCausalLM (Bamba model)
- bart — BartForCausalLM (BART model)
- bert — BertLMHeadModel (BERT model)
- bert-generation — BertGenerationDecoder (Bert Generation model)
- big_bird — BigBirdForCausalLM (BigBird model)
- bigbird_pegasus — BigBirdPegasusForCausalLM (BigBird-Pegasus model)
- biogpt — BioGptForCausalLM (BioGpt model)
- blenderbot — BlenderbotForCausalLM (Blenderbot model)
- blenderbot-small — BlenderbotSmallForCausalLM (BlenderbotSmall model)
- bloom — BloomForCausalLM (BLOOM model)
- camembert — CamembertForCausalLM (CamemBERT model)
- code_llama — LlamaForCausalLM (CodeLlama model)
- codegen — CodeGenForCausalLM (CodeGen model)
- cohere — CohereForCausalLM (Cohere model)
- cohere2 — Cohere2ForCausalLM (Cohere2 model)
- cpmant — CpmAntForCausalLM (CPM-Ant model)
- ctrl — CTRLLMHeadModel (CTRL model)
- data2vec-text — Data2VecTextForCausalLM (Data2VecText model)
- dbrx — DbrxForCausalLM (DBRX model)
- diffllama — DiffLlamaForCausalLM (DiffLlama model)
- electra — ElectraForCausalLM (ELECTRA model)
- emu3 — Emu3ForCausalLM (Emu3 model)
- ernie — ErnieForCausalLM (ERNIE model)
- falcon — FalconForCausalLM (Falcon model)
- falcon_mamba — FalconMambaForCausalLM (FalconMamba model)
- fuyu — FuyuForCausalLM (Fuyu model)
- gemma — GemmaForCausalLM (Gemma model)
- gemma2 — Gemma2ForCausalLM (Gemma2 model)
- git — GitForCausalLM (GIT model)
- glm — GlmForCausalLM (GLM model)
- gpt-sw3 — GPT2LMHeadModel (GPT-Sw3 model)
- gpt2 — GPT2LMHeadModel (OpenAI GPT-2 model)
- gpt_bigcode — GPTBigCodeForCausalLM (GPTBigCode model)
- gpt_neo — GPTNeoForCausalLM (GPT Neo model)
- gpt_neox — GPTNeoXForCausalLM (GPT NeoX model)
- gpt_neox_japanese — GPTNeoXJapaneseForCausalLM (GPT NeoX Japanese model)
- gptj — GPTJForCausalLM (GPT-J model)
- granite — GraniteForCausalLM (Granite model)
- granitemoe — GraniteMoeForCausalLM (GraniteMoe model)
- helium — HeliumForCausalLM (Helium model)
- jamba — JambaForCausalLM (Jamba model)
- jetmoe — JetMoeForCausalLM (JetMoe model)
- llama — LlamaForCausalLM (LLaMA model)
- mamba — MambaForCausalLM (Mamba model)
- mamba2 — Mamba2ForCausalLM (mamba2 model)
- marian — MarianForCausalLM (Marian model)
- mbart — MBartForCausalLM (mBART model)
- mega — MegaForCausalLM (MEGA model)
- megatron-bert — MegatronBertForCausalLM (Megatron-BERT model)
- mistral — MistralForCausalLM (Mistral model)
- mixtral — MixtralForCausalLM (Mixtral model)
- mllama — MllamaForCausalLM (Mllama model)
- moshi — MoshiForCausalLM (Moshi model)
- mpt — MptForCausalLM (MPT model)
- musicgen — MusicgenForCausalLM (MusicGen model)
- musicgen_melody — MusicgenMelodyForCausalLM (MusicGen Melody model)
- mvp — MvpForCausalLM (MVP model)
- nemotron — NemotronForCausalLM (Nemotron model)
- olmo — OlmoForCausalLM (OLMo model)
- olmo2 — Olmo2ForCausalLM (OLMo2 model)
- olmoe — OlmoeForCausalLM (OLMoE model)
- open-llama — OpenLlamaForCausalLM (OpenLlama model)
- openai-gpt — OpenAIGPTLMHeadModel (OpenAI GPT model)
- opt — OPTForCausalLM (OPT model)
- pegasus — PegasusForCausalLM (Pegasus model)
- persimmon — PersimmonForCausalLM (Persimmon model)
- phi — PhiForCausalLM (Phi model)
- phi3 — Phi3ForCausalLM (Phi3 model)
- phimoe — PhimoeForCausalLM (Phimoe model)
- plbart — PLBartForCausalLM (PLBart model)
- prophetnet — ProphetNetForCausalLM (ProphetNet model)
- qdqbert — QDQBertLMHeadModel (QDQBert model)
- qwen2 — Qwen2ForCausalLM (Qwen2 model)
- qwen2_moe — Qwen2MoeForCausalLM (Qwen2MoE model)
- recurrent_gemma — RecurrentGemmaForCausalLM (RecurrentGemma model)
- reformer — ReformerModelWithLMHead (Reformer model)
- rembert — RemBertForCausalLM (RemBERT model)
- roberta — RobertaForCausalLM (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForCausalLM (RoCBert model)
- roformer — RoFormerForCausalLM (RoFormer model)
- rwkv — RwkvForCausalLM (RWKV model)
- speech_to_text_2 — Speech2Text2ForCausalLM (Speech2Text2 model)
- stablelm — StableLmForCausalLM (StableLm model)
- starcoder2 — Starcoder2ForCausalLM (Starcoder2 model)
- transfo-xl — TransfoXLLMHeadModel (Transformer-XL model)
- trocr — TrOCRForCausalLM (TrOCR model)
- whisper — WhisperForCausalLM (Whisper model)
- xglm — XGLMForCausalLM (XGLM model)
- xlm — XLMWithLMHeadModel (XLM model)
- xlm-prophetnet — XLMProphetNetForCausalLM (XLM-ProphetNet model)
- xlm-roberta — XLMRobertaForCausalLM (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForCausalLM (XLM-RoBERTa-XL model)
- xlnet — XLNetLMHeadModel (XLNet model)
- xmod — XmodForCausalLM (X-MOD model)
- zamba — ZambaForCausalLM (Zamba model)
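The lookup described above (model_type first, name matching as a fallback) can be sketched as follows. The table is a hypothetical excerpt of the list above; note that several model types, such as code_llama and llama, share one class. Trying the longest pattern first is an assumption made here so that a name containing roberta-prelayernorm is not captured by the shorter keys roberta or bert. Check the library source before relying on the exact matching order.

```python
from typing import Any, Dict, Optional

# Hypothetical excerpt of the model_type -> class-name table above.
CAUSAL_LM_BY_MODEL_TYPE: Dict[str, str] = {
    "bert": "BertLMHeadModel",
    "code_llama": "LlamaForCausalLM",
    "gpt-sw3": "GPT2LMHeadModel",
    "gpt2": "GPT2LMHeadModel",
    "llama": "LlamaForCausalLM",
    "roberta": "RobertaForCausalLM",
    "roberta-prelayernorm": "RobertaPreLayerNormForCausalLM",
}


def resolve_causal_lm_class(config_dict: Dict[str, Any], name_or_path: str) -> str:
    """Hypothetical resolver: model_type wins, name matching is the fallback."""
    model_type: Optional[str] = config_dict.get("model_type")
    if model_type is not None:
        return CAUSAL_LM_BY_MODEL_TYPE[model_type]
    # Fallback: longer patterns first, because "roberta" and "bert" are
    # substrings of "roberta-prelayernorm".
    for pattern in sorted(CAUSAL_LM_BY_MODEL_TYPE, key=len, reverse=True):
        if pattern in name_or_path:
            return CAUSAL_LM_BY_MODEL_TYPE[pattern]
    raise ValueError(f"Cannot infer a model type from {name_or_path!r}")


print(resolve_causal_lm_class({}, "./checkpoints/roberta-prelayernorm-base"))
# RobertaPreLayerNormForCausalLM
```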
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForCausalLM.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForCausalLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a causal language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BertConfig configuration class: TFBertLMHeadModel (BERT model)
  - CTRLConfig configuration class: TFCTRLLMHeadModel (CTRL model)
  - CamembertConfig configuration class: TFCamembertForCausalLM (CamemBERT model)
  - GPT2Config configuration class: TFGPT2LMHeadModel (OpenAI GPT-2 model)
  - GPTJConfig configuration class: TFGPTJForCausalLM (GPT-J model)
  - MistralConfig configuration class: TFMistralForCausalLM (Mistral model)
  - OPTConfig configuration class: TFOPTForCausalLM (OPT model)
  - OpenAIGPTConfig configuration class: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
  - RemBertConfig configuration class: TFRemBertForCausalLM (RemBERT model)
  - RoFormerConfig configuration class: TFRoFormerForCausalLM (RoFormer model)
  - RobertaConfig configuration class: TFRobertaForCausalLM (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
  - TransfoXLConfig configuration class: TFTransfoXLLMHeadModel (Transformer-XL model)
  - XGLMConfig configuration class: TFXGLMForCausalLM (XGLM model)
  - XLMConfig configuration class: TFXLMWithLMHeadModel (XLM model)
  - XLMRobertaConfig configuration class: TFXLMRobertaForCausalLM (XLM-RoBERTa model)
  - XLNetConfig configuration class: TFXLNetLMHeadModel (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a causal language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
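The proxies example mixes a per-protocol key ('http') with a per-endpoint key ('http://hostname'). Below is a minimal sketch of how such a dictionary is typically resolved. It assumes requests-style priority: scheme plus host first, then scheme, then all://host, then all. select_proxy here is a standalone reimplementation for illustration. Whether your installed version sends Hub downloads through requests is an assumption worth checking.

```python
from typing import Dict, Optional
from urllib.parse import urlparse


def select_proxy(url: str, proxies: Dict[str, str]) -> Optional[str]:
    """Pick the most specific proxy entry for url (requests-style priority)."""
    parts = urlparse(url)
    if parts.hostname is None:
        return proxies.get(parts.scheme, proxies.get("all"))
    for key in (
        f"{parts.scheme}://{parts.hostname}",  # scheme + host: most specific
        parts.scheme,                          # scheme only
        f"all://{parts.hostname}",             # any scheme, this host
        "all",                                 # catch-all
    ):
        if key in proxies:
            return proxies[key]
    return None  # no matching entry: connect directly


# The example dictionary from the parameter description above.
PROXIES = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(select_proxy("http://hostname/model.bin", PROXIES))  # foo.bar:4012
print(select_proxy("http://other.host/model.bin", PROXIES))  # foo.bar:3128
```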
Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- bert — TFBertLMHeadModel (BERT model)
- camembert — TFCamembertForCausalLM (CamemBERT model)
- ctrl — TFCTRLLMHeadModel (CTRL model)
- gpt-sw3 — TFGPT2LMHeadModel (GPT-Sw3 model)
- gpt2 — TFGPT2LMHeadModel (OpenAI GPT-2 model)
- gptj — TFGPTJForCausalLM (GPT-J model)
- mistral — TFMistralForCausalLM (Mistral model)
- openai-gpt — TFOpenAIGPTLMHeadModel (OpenAI GPT model)
- opt — TFOPTForCausalLM (OPT model)
- rembert — TFRemBertForCausalLM (RemBERT model)
- roberta — TFRobertaForCausalLM (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForCausalLM (RoFormer model)
- transfo-xl — TFTransfoXLLMHeadModel (Transformer-XL model)
- xglm — TFXGLMForCausalLM (XGLM model)
- xlm — TFXLMWithLMHeadModel (XLM model)
- xlm-roberta — TFXLMRobertaForCausalLM (XLM-RoBERTa model)
- xlnet — TFXLNetLMHeadModel (XLNet model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForCausalLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForCausalLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a causal language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BartConfig configuration class: FlaxBartForCausalLM (BART model)
  - BertConfig configuration class: FlaxBertForCausalLM (BERT model)
  - BigBirdConfig configuration class: FlaxBigBirdForCausalLM (BigBird model)
  - BloomConfig configuration class: FlaxBloomForCausalLM (BLOOM model)
  - ElectraConfig configuration class: FlaxElectraForCausalLM (ELECTRA model)
  - GPT2Config configuration class: FlaxGPT2LMHeadModel (OpenAI GPT-2 model)
  - GPTJConfig configuration class: FlaxGPTJForCausalLM (GPT-J model)
  - GPTNeoConfig configuration class: FlaxGPTNeoForCausalLM (GPT Neo model)
  - GemmaConfig configuration class: FlaxGemmaForCausalLM (Gemma model)
  - LlamaConfig configuration class: FlaxLlamaForCausalLM (LLaMA model)
  - MistralConfig configuration class: FlaxMistralForCausalLM (Mistral model)
  - OPTConfig configuration class: FlaxOPTForCausalLM (OPT model)
  - RobertaConfig configuration class: FlaxRobertaForCausalLM (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
  - XGLMConfig configuration class: FlaxXGLMForCausalLM (XGLM model)
  - XLMRobertaConfig configuration class: FlaxXLMRobertaForCausalLM (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a causal language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (
str
oros.PathLike
) — Can be either:- A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
- A path to a directory containing model weights saved using
save_pretrained(), e.g.,
./my_model_directory/
. - A path or url to a PyTorch state_dict save file (e.g,
./pt_model/pytorch_model.bin
). In this case,from_pt
should be set toTrue
and a configuration object should be provided asconfig
argument. This loading path is slower than converting the PyTorch model in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) —
Will be passed along to the underlying model
__init__()
method. - config (PretrainedConfig, optional) —
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. Only set this option to True for repositories you trust and whose code you have read, as it will execute code from the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been done already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
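The two kwargs cases above can be condensed into a few lines of plain Python. This is a simplified, hypothetical sketch, not the Transformers implementation. ToyConfig and route_kwargs are made-up names, and a default-constructed dataclass stands in for the configuration that the real method would load from pretrained_model_name_or_path.

```python
from dataclasses import dataclass, fields


@dataclass
class ToyConfig:
    # Hypothetical stand-in for a configuration class; real configs have many more attributes.
    hidden_size: int = 768
    output_attentions: bool = False


def route_kwargs(config=None, config_cls=ToyConfig, **kwargs):
    """Return (config, model_init_kwargs) following the two cases described above."""
    if config is not None:
        # Case 1: a config was passed explicitly. It is treated as final,
        # so every keyword goes straight to the model's __init__.
        return config, dict(kwargs)
    # Case 2: no config was passed. Keys that name a config attribute
    # override it; everything else is forwarded to the model's __init__.
    config_keys = {f.name for f in fields(config_cls)}
    overrides = {k: v for k, v in kwargs.items() if k in config_keys}
    model_kwargs = {k: v for k, v in kwargs.items() if k not in config_keys}
    return config_cls(**overrides), model_kwargs


config, model_kwargs = route_kwargs(output_attentions=True, custom_flag=1)
print(config.output_attentions, model_kwargs)  # True {'custom_flag': 1}
```

So if you pass output_attentions=True without a config, the keyword changes the loaded configuration. If you pass the same keyword together with an explicit config, it reaches the model's __init__ instead.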
Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- bart — FlaxBartForCausalLM (BART model)
- bert — FlaxBertForCausalLM (BERT model)
- big_bird — FlaxBigBirdForCausalLM (BigBird model)
- bloom — FlaxBloomForCausalLM (BLOOM model)
- electra — FlaxElectraForCausalLM (ELECTRA model)
- gemma — FlaxGemmaForCausalLM (Gemma model)
- gpt-sw3 — FlaxGPT2LMHeadModel (GPT-Sw3 model)
- gpt2 — FlaxGPT2LMHeadModel (OpenAI GPT-2 model)
- gpt_neo — FlaxGPTNeoForCausalLM (GPT Neo model)
- gptj — FlaxGPTJForCausalLM (GPT-J model)
- llama — FlaxLlamaForCausalLM (LLaMA model)
- mistral — FlaxMistralForCausalLM (Mistral model)
- opt — FlaxOPTForCausalLM (OPT model)
- roberta — FlaxRobertaForCausalLM (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- xglm — FlaxXGLMForCausalLM (XGLM model)
- xlm-roberta — FlaxXLMRobertaForCausalLM (XLM-RoBERTa model)
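The selection rule described above can be sketched as follows: use model_type when the config provides it, and otherwise match against the name or path. Strings stand in for the Flax classes, only a few entries of the list above are copied, and select_model_class is a hypothetical helper. This sketch tries longer keys first so that overlapping names such as xlm-roberta, roberta and bert resolve to the most specific entry. Check the library source for its exact matching order.

```python
# Hypothetical sketch: strings stand in for the real Flax classes, and only a
# few entries of the mapping listed above are copied.
FLAX_CAUSAL_LM = {
    "bert": "FlaxBertForCausalLM",
    "gpt2": "FlaxGPT2LMHeadModel",
    "llama": "FlaxLlamaForCausalLM",
    "roberta": "FlaxRobertaForCausalLM",
    "xlm-roberta": "FlaxXLMRobertaForCausalLM",
}


def select_model_class(mapping, model_type=None, name_or_path=""):
    # 1. Use the config's model_type when it is known.
    if model_type is not None:
        if model_type not in mapping:
            raise ValueError(f"Unrecognized model_type: {model_type!r}")
        return mapping[model_type]
    # 2. Otherwise fall back to substring matching on the name or path.
    #    Longer keys are tried first so "xlm-roberta" wins over "roberta",
    #    and "roberta" wins over "bert", which it contains.
    for key in sorted(mapping, key=len, reverse=True):
        if key in name_or_path:
            return mapping[key]
    raise ValueError(f"Cannot infer a model class from {name_or_path!r}")


print(select_model_class(FLAX_CAUSAL_LM, model_type="gpt2"))  # FlaxGPT2LMHeadModel
print(select_model_class(FLAX_CAUSAL_LM, name_or_path="FacebookAI/xlm-roberta-base"))  # FlaxXLMRobertaForCausalLM
```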
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForCausalLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForMaskedLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a masked language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: AlbertForMaskedLM (ALBERT model)
  - BartConfig configuration class: BartForConditionalGeneration (BART model)
  - BertConfig configuration class: BertForMaskedLM (BERT model)
  - BigBirdConfig configuration class: BigBirdForMaskedLM (BigBird model)
  - CamembertConfig configuration class: CamembertForMaskedLM (CamemBERT model)
  - ConvBertConfig configuration class: ConvBertForMaskedLM (ConvBERT model)
  - Data2VecTextConfig configuration class: Data2VecTextForMaskedLM (Data2VecText model)
  - DebertaConfig configuration class: DebertaForMaskedLM (DeBERTa model)
  - DebertaV2Config configuration class: DebertaV2ForMaskedLM (DeBERTa-v2 model)
  - DistilBertConfig configuration class: DistilBertForMaskedLM (DistilBERT model)
  - ElectraConfig configuration class: ElectraForMaskedLM (ELECTRA model)
  - ErnieConfig configuration class: ErnieForMaskedLM (ERNIE model)
  - EsmConfig configuration class: EsmForMaskedLM (ESM model)
  - FNetConfig configuration class: FNetForMaskedLM (FNet model)
  - FlaubertConfig configuration class: FlaubertWithLMHeadModel (FlauBERT model)
  - FunnelConfig configuration class: FunnelForMaskedLM (Funnel Transformer model)
  - IBertConfig configuration class: IBertForMaskedLM (I-BERT model)
  - LayoutLMConfig configuration class: LayoutLMForMaskedLM (LayoutLM model)
  - LongformerConfig configuration class: LongformerForMaskedLM (Longformer model)
  - LukeConfig configuration class: LukeForMaskedLM (LUKE model)
  - MBartConfig configuration class: MBartForConditionalGeneration (mBART model)
  - MPNetConfig configuration class: MPNetForMaskedLM (MPNet model)
  - MegaConfig configuration class: MegaForMaskedLM (MEGA model)
  - MegatronBertConfig configuration class: MegatronBertForMaskedLM (Megatron-BERT model)
  - MobileBertConfig configuration class: MobileBertForMaskedLM (MobileBERT model)
  - ModernBertConfig configuration class: ModernBertForMaskedLM (ModernBERT model)
  - MraConfig configuration class: MraForMaskedLM (MRA model)
  - MvpConfig configuration class: MvpForConditionalGeneration (MVP model)
  - NezhaConfig configuration class: NezhaForMaskedLM (Nezha model)
  - NystromformerConfig configuration class: NystromformerForMaskedLM (Nyströmformer model)
  - PerceiverConfig configuration class: PerceiverForMaskedLM (Perceiver model)
  - QDQBertConfig configuration class: QDQBertForMaskedLM (QDQBert model)
  - ReformerConfig configuration class: ReformerForMaskedLM (Reformer model)
  - RemBertConfig configuration class: RemBertForMaskedLM (RemBERT model)
  - RoCBertConfig configuration class: RoCBertForMaskedLM (RoCBert model)
  - RoFormerConfig configuration class: RoFormerForMaskedLM (RoFormer model)
  - RobertaConfig configuration class: RobertaForMaskedLM (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
  - SqueezeBertConfig configuration class: SqueezeBertForMaskedLM (SqueezeBERT model)
  - TapasConfig configuration class: TapasForMaskedLM (TAPAS model)
  - Wav2Vec2Config configuration class: Wav2Vec2ForMaskedLM (Wav2Vec2 model)
  - XLMConfig configuration class: XLMWithLMHeadModel (XLM model)
  - XLMRobertaConfig configuration class: XLMRobertaForMaskedLM (XLM-RoBERTa model)
  - XLMRobertaXLConfig configuration class: XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL model)
  - XmodConfig configuration class: XmodForMaskedLM (X-MOD model)
  - YosoConfig configuration class: YosoForMaskedLM (YOSO model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a masked language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
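The toy example below illustrates the note above. from_config looks up the model class from the type of the configuration object and builds the model with freshly initialized weights, so nothing pretrained is loaded. ToyBertConfig, ToyBertForMaskedLM and this from_config function are hypothetical stand-ins for the real classes listed above.

```python
import random


class ToyBertConfig:
    # Hypothetical stand-in for BertConfig.
    def __init__(self, hidden_size=4):
        self.hidden_size = hidden_size


class ToyBertForMaskedLM:
    # Hypothetical stand-in for BertForMaskedLM.
    def __init__(self, config):
        self.config = config
        # Fresh random initialization: no checkpoint is read.
        self.weights = [random.gauss(0.0, 0.02) for _ in range(config.hidden_size)]


# Keyed by configuration class, like the list of configuration classes above.
MASKED_LM_BY_CONFIG = {ToyBertConfig: ToyBertForMaskedLM}


def from_config(config):
    model_cls = MASKED_LM_BY_CONFIG.get(type(config))
    if model_cls is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    return model_cls(config)


model = from_config(ToyBertConfig(hidden_size=8))
print(type(model).__name__, len(model.weights))  # ToyBertForMaskedLM 8
```

This sketch looks up the exact type of the config, so a subclass of ToyBertConfig would not match. With the real auto classes, the supported way to extend the mapping is to register custom classes.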
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() is a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. Only set this option to True for repositories you trust and whose code you have read, as it will execute code from the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been done already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
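With output_loading_info=True, from_pretrained returns the model together with a dictionary, as in model, loading_info = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased", output_loading_info=True). The hypothetical loading_info helper below only shows what "missing" and "unexpected" keys mean, by comparing two sets of parameter names. The real loader also handles name prefixes, tensor shapes and other details, and the parameter names used here are illustrative.

```python
def loading_info(model_keys, checkpoint_keys):
    """Hypothetical helper: compare parameter names the way the
    missing/unexpected-key report does, ignoring everything else."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    return {
        # Expected by the model but absent from the checkpoint: these weights
        # end up newly (randomly) initialized.
        "missing_keys": sorted(model_keys - checkpoint_keys),
        # Present in the checkpoint but unused by the model: they are ignored.
        "unexpected_keys": sorted(checkpoint_keys - model_keys),
        # The real report can also carry error messages; this sketch never fails.
        "error_msgs": [],
    }


# Illustrative parameter names only.
info = loading_info(
    model_keys=["bert.embeddings.word_embeddings.weight", "cls.predictions.bias"],
    checkpoint_keys=["bert.embeddings.word_embeddings.weight", "cls.seq_relationship.weight"],
)
print(info["missing_keys"])     # ['cls.predictions.bias']
print(info["unexpected_keys"])  # ['cls.seq_relationship.weight']
```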
Instantiate one of the model classes of the library (with a masked language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertForMaskedLM (ALBERT model)
- bart — BartForConditionalGeneration (BART model)
- bert — BertForMaskedLM (BERT model)
- big_bird — BigBirdForMaskedLM (BigBird model)
- camembert — CamembertForMaskedLM (CamemBERT model)
- convbert — ConvBertForMaskedLM (ConvBERT model)
- data2vec-text — Data2VecTextForMaskedLM (Data2VecText model)
- deberta — DebertaForMaskedLM (DeBERTa model)
- deberta-v2 — DebertaV2ForMaskedLM (DeBERTa-v2 model)
- distilbert — DistilBertForMaskedLM (DistilBERT model)
- electra — ElectraForMaskedLM (ELECTRA model)
- ernie — ErnieForMaskedLM (ERNIE model)
- esm — EsmForMaskedLM (ESM model)
- flaubert — FlaubertWithLMHeadModel (FlauBERT model)
- fnet — FNetForMaskedLM (FNet model)
- funnel — FunnelForMaskedLM (Funnel Transformer model)
- ibert — IBertForMaskedLM (I-BERT model)
- layoutlm — LayoutLMForMaskedLM (LayoutLM model)
- longformer — LongformerForMaskedLM (Longformer model)
- luke — LukeForMaskedLM (LUKE model)
- mbart — MBartForConditionalGeneration (mBART model)
- mega — MegaForMaskedLM (MEGA model)
- megatron-bert — MegatronBertForMaskedLM (Megatron-BERT model)
- mobilebert — MobileBertForMaskedLM (MobileBERT model)
- modernbert — ModernBertForMaskedLM (ModernBERT model)
- mpnet — MPNetForMaskedLM (MPNet model)
- mra — MraForMaskedLM (MRA model)
- mvp — MvpForConditionalGeneration (MVP model)
- nezha — NezhaForMaskedLM (Nezha model)
- nystromformer — NystromformerForMaskedLM (Nyströmformer model)
- perceiver — PerceiverForMaskedLM (Perceiver model)
- qdqbert — QDQBertForMaskedLM (QDQBert model)
- reformer — ReformerForMaskedLM (Reformer model)
- rembert — RemBertForMaskedLM (RemBERT model)
- roberta — RobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForMaskedLM (RoCBert model)
- roformer — RoFormerForMaskedLM (RoFormer model)
- squeezebert — SqueezeBertForMaskedLM (SqueezeBERT model)
- tapas — TapasForMaskedLM (TAPAS model)
- wav2vec2 — Wav2Vec2ForMaskedLM (Wav2Vec2 model)
- xlm — XLMWithLMHeadModel (XLM model)
- xlm-roberta — XLMRobertaForMaskedLM (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL model)
- xmod — XmodForMaskedLM (X-MOD model)
- yoso — YosoForMaskedLM (YOSO model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
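To see why evaluation mode matters, the toy layer below mimics the train()/eval() switch of a PyTorch module. In evaluation mode dropout is the identity, while in training mode it randomly zeroes inputs. ToyDropout is a hypothetical stand-in written in plain Python, not PyTorch code.

```python
import random


class ToyDropout:
    """Hypothetical stand-in that mimics the train()/eval() toggle of torch.nn.Module."""

    def __init__(self, p=0.5):
        self.p = p
        self.training = True  # torch.nn.Module instances also start in training mode

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def __call__(self, xs):
        if not self.training:
            return list(xs)  # evaluation mode: dropout is the identity
        keep = 1.0 - self.p
        # Training mode (inverted dropout): zero each input with probability p
        # and rescale the survivors by 1 / (1 - p).
        return [x / keep if random.random() < keep else 0.0 for x in xs]


layer = ToyDropout(p=0.5).eval()  # analogous to the eval() call made by from_pretrained
print(layer([1.0, 2.0, 3.0]))     # [1.0, 2.0, 3.0]
layer.train()                     # needed before fine-tuning
print(layer([1.0, 2.0, 3.0]))     # random: some entries zeroed, the rest doubled
```

A real model follows the same pattern: call model.train() before fine-tuning and model.eval() again before inference.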
Examples:
>>> from transformers import AutoConfig, AutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForMaskedLM.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForMaskedLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a masked language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: TFAlbertForMaskedLM (ALBERT model)
  - BertConfig configuration class: TFBertForMaskedLM (BERT model)
  - CamembertConfig configuration class: TFCamembertForMaskedLM (CamemBERT model)
  - ConvBertConfig configuration class: TFConvBertForMaskedLM (ConvBERT model)
  - DebertaConfig configuration class: TFDebertaForMaskedLM (DeBERTa model)
  - DebertaV2Config configuration class: TFDebertaV2ForMaskedLM (DeBERTa-v2 model)
  - DistilBertConfig configuration class: TFDistilBertForMaskedLM (DistilBERT model)
  - ElectraConfig configuration class: TFElectraForMaskedLM (ELECTRA model)
  - EsmConfig configuration class: TFEsmForMaskedLM (ESM model)
  - FlaubertConfig configuration class: TFFlaubertWithLMHeadModel (FlauBERT model)
  - FunnelConfig configuration class: TFFunnelForMaskedLM (Funnel Transformer model)
  - LayoutLMConfig configuration class: TFLayoutLMForMaskedLM (LayoutLM model)
  - LongformerConfig configuration class: TFLongformerForMaskedLM (Longformer model)
  - MPNetConfig configuration class: TFMPNetForMaskedLM (MPNet model)
  - MobileBertConfig configuration class: TFMobileBertForMaskedLM (MobileBERT model)
  - RemBertConfig configuration class: TFRemBertForMaskedLM (RemBERT model)
  - RoFormerConfig configuration class: TFRoFormerForMaskedLM (RoFormer model)
  - RobertaConfig configuration class: TFRobertaForMaskedLM (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
  - TapasConfig configuration class: TFTapasForMaskedLM (TAPAS model)
  - XLMConfig configuration class: TFXLMWithLMHeadModel (XLM model)
  - XLMRobertaConfig configuration class: TFXLMRobertaForMaskedLM (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a masked language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. Only set this option to True for repositories you trust and whose code you have read, as it will execute code from the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been done already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
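The proxies example above mixes a per-protocol key ('http') with a protocol-plus-host key ('http://hostname'). If the dictionary follows the convention of the requests library (an assumption here), the most specific matching key wins. The select_proxy function below reimplements the lookup order of requests.utils.select_proxy in plain Python for illustration. It is not part of Transformers.

```python
from urllib.parse import urlparse


def select_proxy(url, proxies):
    """Pick a proxy for url, trying the most specific key first.

    The lookup order mirrors requests.utils.select_proxy, which is assumed to
    be the convention the proxies dictionary follows.
    """
    parts = urlparse(url)
    if parts.hostname is None:
        return proxies.get(parts.scheme, proxies.get("all"))
    for key in (
        f"{parts.scheme}://{parts.hostname}",  # protocol + specific host
        parts.scheme,                          # protocol only
        f"all://{parts.hostname}",             # any protocol, specific host
        "all",                                 # catch-all
    ):
        if key in proxies:
            return proxies[key]
    return None


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(select_proxy("http://hostname/model.bin", proxies))     # foo.bar:4012
print(select_proxy("http://example.com/model.bin", proxies))  # foo.bar:3128
print(select_proxy("https://hostname/model.bin", proxies))    # None
```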
Instantiate one of the model classes of the library (with a masked language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — TFAlbertForMaskedLM (ALBERT model)
- bert — TFBertForMaskedLM (BERT model)
- camembert — TFCamembertForMaskedLM (CamemBERT model)
- convbert — TFConvBertForMaskedLM (ConvBERT model)
- deberta — TFDebertaForMaskedLM (DeBERTa model)
- deberta-v2 — TFDebertaV2ForMaskedLM (DeBERTa-v2 model)
- distilbert — TFDistilBertForMaskedLM (DistilBERT model)
- electra — TFElectraForMaskedLM (ELECTRA model)
- esm — TFEsmForMaskedLM (ESM model)
- flaubert — TFFlaubertWithLMHeadModel (FlauBERT model)
- funnel — TFFunnelForMaskedLM (Funnel Transformer model)
- layoutlm — TFLayoutLMForMaskedLM (LayoutLM model)
- longformer — TFLongformerForMaskedLM (Longformer model)
- mobilebert — TFMobileBertForMaskedLM (MobileBERT model)
- mpnet — TFMPNetForMaskedLM (MPNet model)
- rembert — TFRemBertForMaskedLM (RemBERT model)
- roberta — TFRobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForMaskedLM (RoFormer model)
- tapas — TFTapasForMaskedLM (TAPAS model)
- xlm — TFXLMWithLMHeadModel (XLM model)
- xlm-roberta — TFXLMRobertaForMaskedLM (XLM-RoBERTa model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForMaskedLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForMaskedLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a masked language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: FlaxAlbertForMaskedLM (ALBERT model)
  - BartConfig configuration class: FlaxBartForConditionalGeneration (BART model)
  - BertConfig configuration class: FlaxBertForMaskedLM (BERT model)
  - BigBirdConfig configuration class: FlaxBigBirdForMaskedLM (BigBird model)
  - DistilBertConfig configuration class: FlaxDistilBertForMaskedLM (DistilBERT model)
  - ElectraConfig configuration class: FlaxElectraForMaskedLM (ELECTRA model)
  - MBartConfig configuration class: FlaxMBartForConditionalGeneration (mBART model)
  - RoFormerConfig configuration class: FlaxRoFormerForMaskedLM (RoFormer model)
  - RobertaConfig configuration class: FlaxRobertaForMaskedLM (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
  - XLMRobertaConfig configuration class: FlaxXLMRobertaForMaskedLM (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a masked language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a Flax model ahead of time and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. Only set this option to True for repositories you trust and whose code you have read, as it will execute code from the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been done already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a masked language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertForMaskedLM (ALBERT model)
- bart — FlaxBartForConditionalGeneration (BART model)
- bert — FlaxBertForMaskedLM (BERT model)
- big_bird — FlaxBigBirdForMaskedLM (BigBird model)
- distilbert — FlaxDistilBertForMaskedLM (DistilBERT model)
- electra — FlaxElectraForMaskedLM (ELECTRA model)
- mbart — FlaxMBartForConditionalGeneration (mBART model)
- roberta — FlaxRobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForMaskedLM (RoFormer model)
- xlm-roberta — FlaxXLMRobertaForMaskedLM (XLM-RoBERTa model)
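As a rough sketch of the selection rule above (use model_type when the config provides it, otherwise pattern-match on pretrained_model_name_or_path), the function below works on a subset of the mapping listed here. select_model_class is a hypothetical helper that returns class names as strings. Trying longer keys first is this sketch's own way of keeping a path such as xlm-roberta-large from matching the shorter roberta key.

```python
from typing import Optional

# Subset of the masked-LM mapping listed above (model_type -> class name).
FLAX_MASKED_LM_MAPPING = {
    "albert": "FlaxAlbertForMaskedLM",
    "bert": "FlaxBertForMaskedLM",
    "roberta": "FlaxRobertaForMaskedLM",
    "roberta-prelayernorm": "FlaxRobertaPreLayerNormForMaskedLM",
    "xlm-roberta": "FlaxXLMRobertaForMaskedLM",
}


def select_model_class(name_or_path: str, model_type: Optional[str] = None) -> str:
    """Pick a class name from model_type, else by substring match on name_or_path."""
    if model_type is not None:
        return FLAX_MASKED_LM_MAPPING[model_type]
    # Fallback: try longer keys first so "xlm-roberta" wins over "roberta".
    for pattern in sorted(FLAX_MASKED_LM_MAPPING, key=len, reverse=True):
        if pattern in name_or_path:
            return FLAX_MASKED_LM_MAPPING[pattern]
    raise ValueError(f"Could not infer a model class for {name_or_path!r}")


print(select_model_class("google-bert/bert-base-cased", model_type="bert"))  # FlaxBertForMaskedLM
print(select_model_class("./checkpoints/xlm-roberta-large"))  # FlaxXLMRobertaForMaskedLM
```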
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForMaskedLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForMaskGeneration
TFAutoModelForMaskGeneration
AutoModelForSeq2SeqLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BartConfig configuration class: BartForConditionalGeneration (BART model)
  - BigBirdPegasusConfig configuration class: BigBirdPegasusForConditionalGeneration (BigBird-Pegasus model)
  - BlenderbotConfig configuration class: BlenderbotForConditionalGeneration (Blenderbot model)
  - BlenderbotSmallConfig configuration class: BlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
  - EncoderDecoderConfig configuration class: EncoderDecoderModel (Encoder decoder model)
  - FSMTConfig configuration class: FSMTForConditionalGeneration (FairSeq Machine-Translation model)
  - GPTSanJapaneseConfig configuration class: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
  - LEDConfig configuration class: LEDForConditionalGeneration (LED model)
  - LongT5Config configuration class: LongT5ForConditionalGeneration (LongT5 model)
  - M2M100Config configuration class: M2M100ForConditionalGeneration (M2M100 model)
  - MBartConfig configuration class: MBartForConditionalGeneration (mBART model)
  - MT5Config configuration class: MT5ForConditionalGeneration (MT5 model)
  - MarianConfig configuration class: MarianMTModel (Marian model)
  - MvpConfig configuration class: MvpForConditionalGeneration (MVP model)
  - NllbMoeConfig configuration class: NllbMoeForConditionalGeneration (NLLB-MOE model)
  - PLBartConfig configuration class: PLBartForConditionalGeneration (PLBart model)
  - PegasusConfig configuration class: PegasusForConditionalGeneration (Pegasus model)
  - PegasusXConfig configuration class: PegasusXForConditionalGeneration (PEGASUS-X model)
  - ProphetNetConfig configuration class: ProphetNetForConditionalGeneration (ProphetNet model)
  - Qwen2AudioConfig configuration class: Qwen2AudioForConditionalGeneration (Qwen2Audio model)
  - SeamlessM4TConfig configuration class: SeamlessM4TForTextToText (SeamlessM4T model)
  - SeamlessM4Tv2Config configuration class: SeamlessM4Tv2ForTextToText (SeamlessM4Tv2 model)
  - SwitchTransformersConfig configuration class: SwitchTransformersForConditionalGeneration (SwitchTransformers model)
  - T5Config configuration class: T5ForConditionalGeneration (T5 model)
  - UMT5Config configuration class: UMT5ForConditionalGeneration (UMT5 model)
  - XLMProphetNetConfig configuration class: XLMProphetNetForConditionalGeneration (XLM-ProphetNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
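The default-selection sentence for attn_implementation can be written as a small decision function. This is a simplified sketch, not the library's logic: resolve_attn_implementation and its sdpa_supported flag are hypothetical, and real loading also checks whether the model and the installed packages actually support each backend.

```python
import re
from typing import Optional


def _version_tuple(version: str) -> tuple:
    # Keep the first three numeric fields, ignoring local tags such as "+cu121".
    return tuple(int(n) for n in re.findall(r"\d+", version.split("+")[0])[:3])


def resolve_attn_implementation(requested: Optional[str], torch_version: str, sdpa_supported: bool) -> str:
    """Explicit choice wins; otherwise "sdpa" if supported on torch>=2.1.1, else "eager"."""
    if requested is not None:
        if requested not in {"eager", "sdpa", "flash_attention_2"}:
            raise ValueError(f"Unknown attn_implementation: {requested!r}")
        return requested
    if sdpa_supported and _version_tuple(torch_version) >= (2, 1, 1):
        return "sdpa"
    return "eager"


print(resolve_attn_implementation(None, "2.3.0+cu121", sdpa_supported=True))  # sdpa
print(resolve_attn_implementation(None, "2.0.1", sdpa_supported=True))  # eager
```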
Instantiates one of the model classes of the library (with a sequence-to-sequence language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
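A minimal sketch of what from_config does, under stated assumptions: all classes are hypothetical toys, the lookup uses the exact type of the configuration object (an assumption about the matching rule), and the new model gets freshly initialized parameters rather than pretrained weights. Zeros stand in for the random initialization a real model would use, so the output is deterministic.

```python
class ToyConfig:
    """Hypothetical base configuration."""

    def __init__(self, vocab_size: int = 8):
        self.vocab_size = vocab_size


class ToyBartConfig(ToyConfig):
    pass


class ToyT5Config(ToyConfig):
    pass


class ToyModel:
    def __init__(self, config: ToyConfig):
        self.config = config
        # Fresh parameters built from the config alone; no weights file is read.
        self.embeddings = [0.0] * config.vocab_size


class ToyBartForConditionalGeneration(ToyModel):
    pass


class ToyT5ForConditionalGeneration(ToyModel):
    pass


SEQ2SEQ_MAPPING = {
    ToyBartConfig: ToyBartForConditionalGeneration,
    ToyT5Config: ToyT5ForConditionalGeneration,
}


def from_config(config: ToyConfig) -> ToyModel:
    """Build the model class registered for type(config), with untrained parameters."""
    model_class = SEQ2SEQ_MAPPING.get(type(config))
    if model_class is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    return model_class(config)


model = from_config(ToyT5Config(vocab_size=4))
print(type(model).__name__, model.embeddings)  # ToyT5ForConditionalGeneration [0.0, 0.0, 0.0, 0.0]
```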
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from a saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check if using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
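The three accepted forms of pretrained_model_name_or_path, plus the config.json condition for loading a configuration automatically, can be told apart roughly as follows. classify_source is a hypothetical helper for illustration only; the library's own resolution handles more cases (for example URLs and specific weight file names).

```python
from pathlib import Path


def classify_source(name_or_path: str) -> str:
    """Roughly classify a pretrained_model_name_or_path value (illustrative only)."""
    path = Path(name_or_path)
    if path.is_dir():
        # A directory written by save_pretrained(); the config can only be
        # loaded automatically if the directory contains config.json.
        if (path / "config.json").is_file():
            return "local_dir_with_config"
        return "local_dir"
    if name_or_path.endswith(".index"):
        # TensorFlow index checkpoint: needs from_tf=True and an explicit config.
        return "tf_checkpoint"
    # Anything else is treated as a model id on huggingface.co.
    return "hub_id"


print(classify_source("google-t5/t5-base"))  # hub_id (unless such a local folder exists)
print(classify_source("./tf_model/model.ckpt.index"))  # tf_checkpoint
```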
Instantiate one of the model classes of the library (with a sequence-to-sequence language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- bart — BartForConditionalGeneration (BART model)
- bigbird_pegasus — BigBirdPegasusForConditionalGeneration (BigBird-Pegasus model)
- blenderbot — BlenderbotForConditionalGeneration (Blenderbot model)
- blenderbot-small — BlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
- encoder-decoder — EncoderDecoderModel (Encoder decoder model)
- fsmt — FSMTForConditionalGeneration (FairSeq Machine-Translation model)
- gptsan-japanese — GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- led — LEDForConditionalGeneration (LED model)
- longt5 — LongT5ForConditionalGeneration (LongT5 model)
- m2m_100 — M2M100ForConditionalGeneration (M2M100 model)
- marian — MarianMTModel (Marian model)
- mbart — MBartForConditionalGeneration (mBART model)
- mt5 — MT5ForConditionalGeneration (MT5 model)
- mvp — MvpForConditionalGeneration (MVP model)
- nllb-moe — NllbMoeForConditionalGeneration (NLLB-MOE model)
- pegasus — PegasusForConditionalGeneration (Pegasus model)
- pegasus_x — PegasusXForConditionalGeneration (PEGASUS-X model)
- plbart — PLBartForConditionalGeneration (PLBart model)
- prophetnet — ProphetNetForConditionalGeneration (ProphetNet model)
- qwen2_audio — Qwen2AudioForConditionalGeneration (Qwen2Audio model)
- seamless_m4t — SeamlessM4TForTextToText (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4Tv2ForTextToText (SeamlessM4Tv2 model)
- switch_transformers — SwitchTransformersForConditionalGeneration (SwitchTransformers model)
- t5 — T5ForConditionalGeneration (T5 model)
- umt5 — UMT5ForConditionalGeneration (UMT5 model)
- xlm-prophetnet — XLMProphetNetForConditionalGeneration (XLM-ProphetNet model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # Update configuration during loading
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/t5_tf_model_config.json")
>>> model = AutoModelForSeq2SeqLM.from_pretrained(
... "./tf_model/t5_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForSeq2SeqLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BartConfig configuration class: TFBartForConditionalGeneration (BART model)
  - BlenderbotConfig configuration class: TFBlenderbotForConditionalGeneration (Blenderbot model)
  - BlenderbotSmallConfig configuration class: TFBlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
  - EncoderDecoderConfig configuration class: TFEncoderDecoderModel (Encoder decoder model)
  - LEDConfig configuration class: TFLEDForConditionalGeneration (LED model)
  - MBartConfig configuration class: TFMBartForConditionalGeneration (mBART model)
  - MT5Config configuration class: TFMT5ForConditionalGeneration (MT5 model)
  - MarianConfig configuration class: TFMarianMTModel (Marian model)
  - PegasusConfig configuration class: TFPegasusForConditionalGeneration (Pegasus model)
  - T5Config configuration class: TFT5ForConditionalGeneration (T5 model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
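With output_loading_info=True, the parameter description above says a dictionary of missing keys, unexpected keys and error messages is returned as well. The sketch below shows how the first two can be derived by comparing parameter names; loading_info is a hypothetical helper, not the library's function, and it simply leaves error_msgs empty.

```python
def loading_info(model_keys, checkpoint_keys):
    """Compare the parameter names a model expects with those found in a checkpoint."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    return {
        # Names the model expects but the checkpoint does not provide.
        "missing_keys": sorted(model_keys - checkpoint_keys),
        # Names present in the checkpoint that the model has no slot for.
        "unexpected_keys": sorted(checkpoint_keys - model_keys),
        "error_msgs": [],
    }


info = loading_info(
    model_keys=["encoder.weight", "decoder.weight", "lm_head.weight"],
    checkpoint_keys=["encoder.weight", "decoder.weight", "pooler.weight"],
)
print(info)
# {'missing_keys': ['lm_head.weight'], 'unexpected_keys': ['pooler.weight'], 'error_msgs': []}
```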
Instantiate one of the model classes of the library (with a sequence-to-sequence language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- bart — TFBartForConditionalGeneration (BART model)
- blenderbot — TFBlenderbotForConditionalGeneration (Blenderbot model)
- blenderbot-small — TFBlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
- encoder-decoder — TFEncoderDecoderModel (Encoder decoder model)
- led — TFLEDForConditionalGeneration (LED model)
- marian — TFMarianMTModel (Marian model)
- mbart — TFMBartForConditionalGeneration (mBART model)
- mt5 — TFMT5ForConditionalGeneration (MT5 model)
- pegasus — TFPegasusForConditionalGeneration (Pegasus model)
- t5 — TFT5ForConditionalGeneration (T5 model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # Update configuration during loading
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/t5_pt_model_config.json")
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained(
... "./pt_model/t5_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSeq2SeqLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BartConfig configuration class: FlaxBartForConditionalGeneration (BART model)
  - BlenderbotConfig configuration class: FlaxBlenderbotForConditionalGeneration (Blenderbot model)
  - BlenderbotSmallConfig configuration class: FlaxBlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
  - EncoderDecoderConfig configuration class: FlaxEncoderDecoderModel (Encoder decoder model)
  - LongT5Config configuration class: FlaxLongT5ForConditionalGeneration (LongT5 model)
  - MBartConfig configuration class: FlaxMBartForConditionalGeneration (mBART model)
  - MT5Config configuration class: FlaxMT5ForConditionalGeneration (MT5 model)
  - MarianConfig configuration class: FlaxMarianMTModel (Marian model)
  - PegasusConfig configuration class: FlaxPegasusForConditionalGeneration (Pegasus model)
  - T5Config configuration class: FlaxT5ForConditionalGeneration (T5 model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the PyTorch model to a Flax model and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a sequence-to-sequence language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- bart — FlaxBartForConditionalGeneration (BART model)
- blenderbot — FlaxBlenderbotForConditionalGeneration (Blenderbot model)
- blenderbot-small — FlaxBlenderbotSmallForConditionalGeneration (BlenderbotSmall model)
- encoder-decoder — FlaxEncoderDecoderModel (Encoder decoder model)
- longt5 — FlaxLongT5ForConditionalGeneration (LongT5 model)
- marian — FlaxMarianMTModel (Marian model)
- mbart — FlaxMBartForConditionalGeneration (mBART model)
- mt5 — FlaxMT5ForConditionalGeneration (MT5 model)
- pegasus — FlaxPegasusForConditionalGeneration (Pegasus model)
- t5 — FlaxT5ForConditionalGeneration (T5 model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/t5_pt_model_config.json")
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
... "./pt_model/t5_pytorch_model.bin", from_pt=True, config=config
... )
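Not every model type is available in every backend. Using the model_type keys transcribed from the three sequence-to-sequence lists in this section (they may differ in other library versions), plain set operations show which types can be loaded with all three of AutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM and FlaxAutoModelForSeq2SeqLM, and which ones are backend-specific.

```python
# model_type keys copied from the sequence-to-sequence lists in this section.
PT = {
    "bart", "bigbird_pegasus", "blenderbot", "blenderbot-small", "encoder-decoder",
    "fsmt", "gptsan-japanese", "led", "longt5", "m2m_100", "marian", "mbart", "mt5",
    "mvp", "nllb-moe", "pegasus", "pegasus_x", "plbart", "prophetnet", "qwen2_audio",
    "seamless_m4t", "seamless_m4t_v2", "switch_transformers", "t5", "umt5",
    "xlm-prophetnet",
}
TF = {"bart", "blenderbot", "blenderbot-small", "encoder-decoder", "led",
      "marian", "mbart", "mt5", "pegasus", "t5"}
FLAX = {"bart", "blenderbot", "blenderbot-small", "encoder-decoder", "longt5",
        "marian", "mbart", "mt5", "pegasus", "t5"}

# Model types usable from all three backends.
all_backends = sorted(PT & TF & FLAX)
print(all_backends)
# ['bart', 'blenderbot', 'blenderbot-small', 'encoder-decoder', 'marian', 'mbart', 'mt5', 'pegasus', 't5']

# Backend-specific leftovers among the TF and Flax lists.
print(sorted(TF - FLAX), sorted(FLAX - TF))  # ['led'] ['longt5']
```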
AutoModelForSequenceClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: AlbertForSequenceClassification (ALBERT model)
  - BartConfig configuration class: BartForSequenceClassification (BART model)
  - BertConfig configuration class: BertForSequenceClassification (BERT model)
  - BigBirdConfig configuration class: BigBirdForSequenceClassification (BigBird model)
  - BigBirdPegasusConfig configuration class: BigBirdPegasusForSequenceClassification (BigBird-Pegasus model)
  - BioGptConfig configuration class: BioGptForSequenceClassification (BioGpt model)
  - BloomConfig configuration class: BloomForSequenceClassification (BLOOM model)
  - CTRLConfig configuration class: CTRLForSequenceClassification (CTRL model)
  - CamembertConfig configuration class: CamembertForSequenceClassification (CamemBERT model)
  - CanineConfig configuration class: CanineForSequenceClassification (CANINE model)
  - ConvBertConfig configuration class: ConvBertForSequenceClassification (ConvBERT model)
  - Data2VecTextConfig configuration class: Data2VecTextForSequenceClassification (Data2VecText model)
  - DebertaConfig configuration class: DebertaForSequenceClassification (DeBERTa model)
  - DebertaV2Config configuration class: DebertaV2ForSequenceClassification (DeBERTa-v2 model)
  - DiffLlamaConfig configuration class: DiffLlamaForSequenceClassification (DiffLlama model)
  - DistilBertConfig configuration class: DistilBertForSequenceClassification (DistilBERT model)
  - ElectraConfig configuration class: ElectraForSequenceClassification (ELECTRA model)
  - ErnieConfig configuration class: ErnieForSequenceClassification (ERNIE model)
  - ErnieMConfig configuration class: ErnieMForSequenceClassification (ErnieM model)
  - EsmConfig configuration class: EsmForSequenceClassification (ESM model)
  - FNetConfig configuration class: FNetForSequenceClassification (FNet model)
  - FalconConfig configuration class: FalconForSequenceClassification (Falcon model)
  - FlaubertConfig configuration class: FlaubertForSequenceClassification (FlauBERT model)
  - FunnelConfig configuration class: FunnelForSequenceClassification (Funnel Transformer model)
  - GPT2Config configuration class: GPT2ForSequenceClassification (OpenAI GPT-2 model)
  - GPTBigCodeConfig configuration class: GPTBigCodeForSequenceClassification (GPTBigCode model)
  - GPTJConfig configuration class: GPTJForSequenceClassification (GPT-J model)
  - GPTNeoConfig configuration class: GPTNeoForSequenceClassification (GPT Neo model)
  - GPTNeoXConfig configuration class: GPTNeoXForSequenceClassification (GPT NeoX model)
  - Gemma2Config configuration class: Gemma2ForSequenceClassification (Gemma2 model)
  - GemmaConfig configuration class: GemmaForSequenceClassification (Gemma model)
  - GlmConfig configuration class: GlmForSequenceClassification (GLM model)
  - HeliumConfig configuration class: HeliumForSequenceClassification (Helium model)
  - IBertConfig configuration class: IBertForSequenceClassification (I-BERT model)
  - JambaConfig configuration class: JambaForSequenceClassification (Jamba model)
  - JetMoeConfig configuration class: JetMoeForSequenceClassification (JetMoe model)
  - LEDConfig configuration class: LEDForSequenceClassification (LED model)
  - LayoutLMConfig configuration class: LayoutLMForSequenceClassification (LayoutLM model)
  - LayoutLMv2Config configuration class: LayoutLMv2ForSequenceClassification (LayoutLMv2 model)
  - LayoutLMv3Config configuration class: LayoutLMv3ForSequenceClassification (LayoutLMv3 model)
  - LiltConfig configuration class: LiltForSequenceClassification (LiLT model)
  - LlamaConfig configuration class: LlamaForSequenceClassification (LLaMA model)
  - LongformerConfig configuration class: LongformerForSequenceClassification (Longformer model)
  - LukeConfig configuration class: LukeForSequenceClassification (LUKE model)
  - MBartConfig configuration class: MBartForSequenceClassification (mBART model)
  - MPNetConfig configuration class: MPNetForSequenceClassification (MPNet model)
  - MT5Config configuration class: MT5ForSequenceClassification (MT5 model)
  - MarkupLMConfig configuration class: MarkupLMForSequenceClassification (MarkupLM model)
  - MegaConfig configuration class: MegaForSequenceClassification (MEGA model)
  - MegatronBertConfig configuration class: MegatronBertForSequenceClassification (Megatron-BERT model)
  - MistralConfig configuration class: MistralForSequenceClassification (Mistral model)
  - MixtralConfig configuration class: MixtralForSequenceClassification (Mixtral model)
  - MobileBertConfig configuration class: MobileBertForSequenceClassification (MobileBERT model)
  - ModernBertConfig configuration class: ModernBertForSequenceClassification (ModernBERT model)
  - MptConfig configuration class: MptForSequenceClassification
(MPT model)MraConfig
configuration class:MraForSequenceClassification
(MRA model)MvpConfig
configuration class:MvpForSequenceClassification
(MVP model)NemotronConfig
configuration class:NemotronForSequenceClassification
(Nemotron model)NezhaConfig
configuration class:NezhaForSequenceClassification
(Nezha model)NystromformerConfig
configuration class:NystromformerForSequenceClassification
(Nyströmformer model)OPTConfig
configuration class:OPTForSequenceClassification
(OPT model)- OpenAIGPTConfig configuration class: OpenAIGPTForSequenceClassification (OpenAI GPT model)
OpenLlamaConfig
configuration class:OpenLlamaForSequenceClassification
(OpenLlama model)PLBartConfig
configuration class:PLBartForSequenceClassification
(PLBart model)PerceiverConfig
configuration class:PerceiverForSequenceClassification
(Perceiver model)PersimmonConfig
configuration class:PersimmonForSequenceClassification
(Persimmon model)Phi3Config
configuration class:Phi3ForSequenceClassification
(Phi3 model)PhiConfig
configuration class:PhiForSequenceClassification
(Phi model)PhimoeConfig
configuration class:PhimoeForSequenceClassification
(Phimoe model)QDQBertConfig
configuration class:QDQBertForSequenceClassification
(QDQBert model)Qwen2Config
configuration class:Qwen2ForSequenceClassification
(Qwen2 model)Qwen2MoeConfig
configuration class:Qwen2MoeForSequenceClassification
(Qwen2MoE model)ReformerConfig
configuration class:ReformerForSequenceClassification
(Reformer model)RemBertConfig
configuration class:RemBertForSequenceClassification
(RemBERT model)RoCBertConfig
configuration class:RoCBertForSequenceClassification
(RoCBert model)RoFormerConfig
configuration class:RoFormerForSequenceClassification
(RoFormer model)RobertaConfig
configuration class:RobertaForSequenceClassification
(RoBERTa model)RobertaPreLayerNormConfig
configuration class:RobertaPreLayerNormForSequenceClassification
(RoBERTa-PreLayerNorm model)SqueezeBertConfig
configuration class:SqueezeBertForSequenceClassification
(SqueezeBERT model)StableLmConfig
configuration class:StableLmForSequenceClassification
(StableLm model)Starcoder2Config
configuration class:Starcoder2ForSequenceClassification
(Starcoder2 model)T5Config
configuration class:T5ForSequenceClassification
(T5 model)TapasConfig
configuration class:TapasForSequenceClassification
(TAPAS model)TransfoXLConfig
configuration class:TransfoXLForSequenceClassification
(Transformer-XL model)UMT5Config
configuration class:UMT5ForSequenceClassification
(UMT5 model)XLMConfig
configuration class:XLMForSequenceClassification
(XLM model)XLMRobertaConfig
configuration class:XLMRobertaForSequenceClassification
(XLM-RoBERTa model)XLMRobertaXLConfig
configuration class:XLMRobertaXLForSequenceClassification
(XLM-RoBERTa-XL model)XLNetConfig
configuration class:XLNetForSequenceClassification
(XLNet model)XmodConfig
configuration class:XmodForSequenceClassification
(X-MOD model)YosoConfig
configuration class:YosoForSequenceClassification
(YOSO model)ZambaConfig
configuration class:ZambaForSequenceClassification
(Zamba model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() would be a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
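The second kwargs case can be observed on the configuration side with AutoConfig.from_pretrained(..., return_unused_kwargs=True), which returns the configuration together with the keyword arguments it did not consume. This sketch assumes a small config saved to a temporary directory; my_extra_arg is a hypothetical name used only to show a key that is not a configuration attribute. How the model loader then forwards leftover keys to __init__ is not shown here.

```python
import tempfile

from transformers import AutoConfig, BertConfig

with tempfile.TemporaryDirectory() as tmp_dir:
    # Save a small config locally so the example needs no network access.
    BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
               intermediate_size=64).save_pretrained(tmp_dir)

    # output_attentions is a configuration attribute, so it overrides the saved value.
    # my_extra_arg (hypothetical) is not, so it is handed back as unused.
    config, unused_kwargs = AutoConfig.from_pretrained(
        tmp_dir,
        output_attentions=True,
        my_extra_arg="left for the model",
        return_unused_kwargs=True,
    )

print(config.output_attentions)  # True
print(unused_kwargs)
```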
Instantiate one of the model classes of the library (with a sequence classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertForSequenceClassification (ALBERT model)
- bart — BartForSequenceClassification (BART model)
- bert — BertForSequenceClassification (BERT model)
- big_bird — BigBirdForSequenceClassification (BigBird model)
- bigbird_pegasus — BigBirdPegasusForSequenceClassification (BigBird-Pegasus model)
- biogpt — BioGptForSequenceClassification (BioGpt model)
- bloom — BloomForSequenceClassification (BLOOM model)
- camembert — CamembertForSequenceClassification (CamemBERT model)
- canine — CanineForSequenceClassification (CANINE model)
- code_llama — LlamaForSequenceClassification (CodeLlama model)
- convbert — ConvBertForSequenceClassification (ConvBERT model)
- ctrl — CTRLForSequenceClassification (CTRL model)
- data2vec-text — Data2VecTextForSequenceClassification (Data2VecText model)
- deberta — DebertaForSequenceClassification (DeBERTa model)
- deberta-v2 — DebertaV2ForSequenceClassification (DeBERTa-v2 model)
- diffllama — DiffLlamaForSequenceClassification (DiffLlama model)
- distilbert — DistilBertForSequenceClassification (DistilBERT model)
- electra — ElectraForSequenceClassification (ELECTRA model)
- ernie — ErnieForSequenceClassification (ERNIE model)
- ernie_m — ErnieMForSequenceClassification (ErnieM model)
- esm — EsmForSequenceClassification (ESM model)
- falcon — FalconForSequenceClassification (Falcon model)
- flaubert — FlaubertForSequenceClassification (FlauBERT model)
- fnet — FNetForSequenceClassification (FNet model)
- funnel — FunnelForSequenceClassification (Funnel Transformer model)
- gemma — GemmaForSequenceClassification (Gemma model)
- gemma2 — Gemma2ForSequenceClassification (Gemma2 model)
- glm — GlmForSequenceClassification (GLM model)
- gpt-sw3 — GPT2ForSequenceClassification (GPT-Sw3 model)
- gpt2 — GPT2ForSequenceClassification (OpenAI GPT-2 model)
- gpt_bigcode — GPTBigCodeForSequenceClassification (GPTBigCode model)
- gpt_neo — GPTNeoForSequenceClassification (GPT Neo model)
- gpt_neox — GPTNeoXForSequenceClassification (GPT NeoX model)
- gptj — GPTJForSequenceClassification (GPT-J model)
- helium — HeliumForSequenceClassification (Helium model)
- ibert — IBertForSequenceClassification (I-BERT model)
- jamba — JambaForSequenceClassification (Jamba model)
- jetmoe — JetMoeForSequenceClassification (JetMoe model)
- layoutlm — LayoutLMForSequenceClassification (LayoutLM model)
- layoutlmv2 — LayoutLMv2ForSequenceClassification (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3ForSequenceClassification (LayoutLMv3 model)
- led — LEDForSequenceClassification (LED model)
- lilt — LiltForSequenceClassification (LiLT model)
- llama — LlamaForSequenceClassification (LLaMA model)
- longformer — LongformerForSequenceClassification (Longformer model)
- luke — LukeForSequenceClassification (LUKE model)
- markuplm — MarkupLMForSequenceClassification (MarkupLM model)
- mbart — MBartForSequenceClassification (mBART model)
- mega — MegaForSequenceClassification (MEGA model)
- megatron-bert — MegatronBertForSequenceClassification (Megatron-BERT model)
- mistral — MistralForSequenceClassification (Mistral model)
- mixtral — MixtralForSequenceClassification (Mixtral model)
- mobilebert — MobileBertForSequenceClassification (MobileBERT model)
- modernbert — ModernBertForSequenceClassification (ModernBERT model)
- mpnet — MPNetForSequenceClassification (MPNet model)
- mpt — MptForSequenceClassification (MPT model)
- mra — MraForSequenceClassification (MRA model)
- mt5 — MT5ForSequenceClassification (MT5 model)
- mvp — MvpForSequenceClassification (MVP model)
- nemotron — NemotronForSequenceClassification (Nemotron model)
- nezha — NezhaForSequenceClassification (Nezha model)
- nystromformer — NystromformerForSequenceClassification (Nyströmformer model)
- open-llama — OpenLlamaForSequenceClassification (OpenLlama model)
- openai-gpt — OpenAIGPTForSequenceClassification (OpenAI GPT model)
- opt — OPTForSequenceClassification (OPT model)
- perceiver — PerceiverForSequenceClassification (Perceiver model)
- persimmon — PersimmonForSequenceClassification (Persimmon model)
- phi — PhiForSequenceClassification (Phi model)
- phi3 — Phi3ForSequenceClassification (Phi3 model)
- phimoe — PhimoeForSequenceClassification (Phimoe model)
- plbart — PLBartForSequenceClassification (PLBart model)
- qdqbert — QDQBertForSequenceClassification (QDQBert model)
- qwen2 — Qwen2ForSequenceClassification (Qwen2 model)
- qwen2_moe — Qwen2MoeForSequenceClassification (Qwen2MoE model)
- reformer — ReformerForSequenceClassification (Reformer model)
- rembert — RemBertForSequenceClassification (RemBERT model)
- roberta — RobertaForSequenceClassification (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForSequenceClassification (RoCBert model)
- roformer — RoFormerForSequenceClassification (RoFormer model)
- squeezebert — SqueezeBertForSequenceClassification (SqueezeBERT model)
- stablelm — StableLmForSequenceClassification (StableLm model)
- starcoder2 — Starcoder2ForSequenceClassification (Starcoder2 model)
- t5 — T5ForSequenceClassification (T5 model)
- tapas — TapasForSequenceClassification (TAPAS model)
- transfo-xl — TransfoXLForSequenceClassification (Transformer-XL model)
- umt5 — UMT5ForSequenceClassification (UMT5 model)
- xlm — XLMForSequenceClassification (XLM model)
- xlm-roberta — XLMRobertaForSequenceClassification (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForSequenceClassification (XLM-RoBERTa-XL model)
- xlnet — XLNetForSequenceClassification (XLNet model)
- xmod — XmodForSequenceClassification (X-MOD model)
- yoso — YosoForSequenceClassification (YOSO model)
- zamba — ZambaForSequenceClassification (Zamba model)
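The selection rule above can be illustrated with a simplified, pure-Python sketch. It uses only an excerpt of the table, and resolve_class_name is a hypothetical helper, not a library function. It assumes the fallback tries longer keys first, which matters because "roberta" contains the substring "bert". The library's real lookup may differ in detail.

```python
from typing import Optional

# Excerpt of the model_type -> class-name table above (not the full mapping).
SEQ_CLS_CLASSES = {
    "bert": "BertForSequenceClassification",
    "roberta": "RobertaForSequenceClassification",
    "xlm-roberta": "XLMRobertaForSequenceClassification",
    "xlm-roberta-xl": "XLMRobertaXLForSequenceClassification",
    "gpt2": "GPT2ForSequenceClassification",
    "code_llama": "LlamaForSequenceClassification",
    "llama": "LlamaForSequenceClassification",
}


def resolve_class_name(model_type: Optional[str], name_or_path: str) -> str:
    """Hypothetical illustration of how an auto class could pick a model class."""
    # 1) Prefer the model_type stored in the configuration.
    if model_type is not None:
        return SEQ_CLS_CLASSES[model_type]
    # 2) Fall back to substring matching on the name or path, longest keys first,
    #    so "xlm-roberta-xl" beats "xlm-roberta", and "roberta" beats "bert".
    for pattern in sorted(SEQ_CLS_CLASSES, key=len, reverse=True):
        if pattern in name_or_path:
            return SEQ_CLS_CLASSES[pattern]
    raise ValueError(f"Unrecognized model in {name_or_path!r}")


print(resolve_class_name("bert", "./my_model_directory/"))  # BertForSequenceClassification
print(resolve_class_name(None, "my-org/xlm-roberta-xl-finetuned"))  # XLMRobertaXLForSequenceClassification
```

Checking model_type first is the reliable path; name matching is only a heuristic for checkpoints whose configuration lacks it.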
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
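A quick way to confirm the mode after loading, assuming PyTorch and a tiny model saved to a temporary directory (the config sizes are arbitrary):

```python
import tempfile

from transformers import AutoModelForSequenceClassification, BertConfig

# Arbitrary tiny sizes so the example runs quickly without downloads.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)

with tempfile.TemporaryDirectory() as tmp_dir:
    # A model built with from_config() is a plain torch.nn.Module, so it starts in training mode.
    AutoModelForSequenceClassification.from_config(config).save_pretrained(tmp_dir)

    # from_pretrained() puts the model in evaluation mode before returning it.
    model = AutoModelForSequenceClassification.from_pretrained(tmp_dir)
    loaded_in_eval = not model.training
    print(loaded_in_eval)  # True: dropout layers are inactive

    # Switch back to training mode before fine-tuning.
    model.train()
    print(model.training)  # True
```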
Examples:
>>> from transformers import AutoConfig, AutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForSequenceClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
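A further offline variant, assuming a model saved to a temporary directory with save_pretrained(): local_files_only=True avoids any Hub request, and output_loading_info=True additionally returns the dictionary of missing keys, unexpected keys and error messages described in the parameter list. The exact set of entries in that dictionary may vary between releases.

```python
import tempfile

from transformers import AutoModelForSequenceClassification, BertConfig

# Arbitrary tiny sizes; the model is saved locally so no download is needed.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)

with tempfile.TemporaryDirectory() as tmp_dir:
    AutoModelForSequenceClassification.from_config(config).save_pretrained(tmp_dir)

    # local_files_only=True: resolve everything from disk, never contact the Hub.
    # output_loading_info=True: return (model, loading_info) instead of just the model.
    model, loading_info = AutoModelForSequenceClassification.from_pretrained(
        tmp_dir, local_files_only=True, output_loading_info=True
    )

    # A clean save/load round trip should leave nothing missing or unexpected.
    print(list(loading_info["missing_keys"]), list(loading_info["unexpected_keys"]))
```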
TFAutoModelForSequenceClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__()
(throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertForSequenceClassification (ALBERT model)
- BartConfig configuration class: TFBartForSequenceClassification (BART model)
- BertConfig configuration class: TFBertForSequenceClassification (BERT model)
- CTRLConfig configuration class: TFCTRLForSequenceClassification (CTRL model)
- CamembertConfig configuration class: TFCamembertForSequenceClassification (CamemBERT model)
- ConvBertConfig configuration class: TFConvBertForSequenceClassification (ConvBERT model)
- DebertaConfig configuration class: TFDebertaForSequenceClassification (DeBERTa model)
- DebertaV2Config configuration class: TFDebertaV2ForSequenceClassification (DeBERTa-v2 model)
- DistilBertConfig configuration class: TFDistilBertForSequenceClassification (DistilBERT model)
- ElectraConfig configuration class: TFElectraForSequenceClassification (ELECTRA model)
- EsmConfig configuration class: TFEsmForSequenceClassification (ESM model)
- FlaubertConfig configuration class: TFFlaubertForSequenceClassification (FlauBERT model)
- FunnelConfig configuration class: TFFunnelForSequenceClassification (Funnel Transformer model)
- GPT2Config configuration class: TFGPT2ForSequenceClassification (OpenAI GPT-2 model)
- GPTJConfig configuration class: TFGPTJForSequenceClassification (GPT-J model)
- LayoutLMConfig configuration class: TFLayoutLMForSequenceClassification (LayoutLM model)
- LayoutLMv3Config configuration class: TFLayoutLMv3ForSequenceClassification (LayoutLMv3 model)
- LongformerConfig configuration class: TFLongformerForSequenceClassification (Longformer model)
- MPNetConfig configuration class: TFMPNetForSequenceClassification (MPNet model)
- MistralConfig configuration class: TFMistralForSequenceClassification (Mistral model)
- MobileBertConfig configuration class: TFMobileBertForSequenceClassification (MobileBERT model)
- OpenAIGPTConfig configuration class: TFOpenAIGPTForSequenceClassification (OpenAI GPT model)
- RemBertConfig configuration class: TFRemBertForSequenceClassification (RemBERT model)
- RoFormerConfig configuration class: TFRoFormerForSequenceClassification (RoFormer model)
- RobertaConfig configuration class: TFRobertaForSequenceClassification (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- TapasConfig configuration class: TFTapasForSequenceClassification (TAPAS model)
- TransfoXLConfig configuration class: TFTransfoXLForSequenceClassification (Transformer-XL model)
- XLMConfig configuration class: TFXLMForSequenceClassification (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForSequenceClassification (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetForSequenceClassification (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a sequence classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — TFAlbertForSequenceClassification (ALBERT model)
- bart — TFBartForSequenceClassification (BART model)
- bert — TFBertForSequenceClassification (BERT model)
- camembert — TFCamembertForSequenceClassification (CamemBERT model)
- convbert — TFConvBertForSequenceClassification (ConvBERT model)
- ctrl — TFCTRLForSequenceClassification (CTRL model)
- deberta — TFDebertaForSequenceClassification (DeBERTa model)
- deberta-v2 — TFDebertaV2ForSequenceClassification (DeBERTa-v2 model)
- distilbert — TFDistilBertForSequenceClassification (DistilBERT model)
- electra — TFElectraForSequenceClassification (ELECTRA model)
- esm — TFEsmForSequenceClassification (ESM model)
- flaubert — TFFlaubertForSequenceClassification (FlauBERT model)
- funnel — TFFunnelForSequenceClassification (Funnel Transformer model)
- gpt-sw3 — TFGPT2ForSequenceClassification (GPT-Sw3 model)
- gpt2 — TFGPT2ForSequenceClassification (OpenAI GPT-2 model)
- gptj — TFGPTJForSequenceClassification (GPT-J model)
- layoutlm — TFLayoutLMForSequenceClassification (LayoutLM model)
- layoutlmv3 — TFLayoutLMv3ForSequenceClassification (LayoutLMv3 model)
- longformer — TFLongformerForSequenceClassification (Longformer model)
- mistral — TFMistralForSequenceClassification (Mistral model)
- mobilebert — TFMobileBertForSequenceClassification (MobileBERT model)
- mpnet — TFMPNetForSequenceClassification (MPNet model)
- openai-gpt — TFOpenAIGPTForSequenceClassification (OpenAI GPT model)
- rembert — TFRemBertForSequenceClassification (RemBERT model)
- roberta — TFRobertaForSequenceClassification (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForSequenceClassification (RoFormer model)
- tapas — TFTapasForSequenceClassification (TAPAS model)
- transfo-xl — TFTransfoXLForSequenceClassification (Transformer-XL model)
- xlm — TFXLMForSequenceClassification (XLM model)
- xlm-roberta — TFXLMRobertaForSequenceClassification (XLM-RoBERTa model)
- xlnet — TFXLNetForSequenceClassification (XLNet model)
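Coverage differs between backends: for example, llama appears in the PyTorch table for AutoModelForSequenceClassification but not in the TensorFlow table above. The lookup below is a hypothetical helper built only from excerpts of those two tables; it also reflects the naming convention visible in the lists, where TensorFlow classes carry a TF prefix.

```python
from typing import Dict, Optional

# Excerpts of the PyTorch and TensorFlow model_type tables in this section.
PT_SEQ_CLS = {
    "bert": "BertForSequenceClassification",
    "gpt2": "GPT2ForSequenceClassification",
    "llama": "LlamaForSequenceClassification",
    "mistral": "MistralForSequenceClassification",
}
TF_SEQ_CLS = {
    "bert": "TFBertForSequenceClassification",
    "gpt2": "TFGPT2ForSequenceClassification",
    "mistral": "TFMistralForSequenceClassification",
}


def backends_for(model_type: str) -> Dict[str, Optional[str]]:
    """Report which excerpted backend table lists a class for model_type (None if absent)."""
    return {
        "pytorch": PT_SEQ_CLS.get(model_type),
        "tensorflow": TF_SEQ_CLS.get(model_type),
    }


for mt in ("bert", "llama"):
    print(mt, backends_for(mt))
```

For llama, the TensorFlow entry comes back as None, so TFAutoModelForSequenceClassification would not be the right entry point for that architecture.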
Examples:
>>> from transformers import AutoConfig, TFAutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForSequenceClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSequenceClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__()
(throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) —
The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertForSequenceClassification (ALBERT model)
- BartConfig configuration class: FlaxBartForSequenceClassification (BART model)
- BertConfig configuration class: FlaxBertForSequenceClassification (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForSequenceClassification (BigBird model)
- DistilBertConfig configuration class: FlaxDistilBertForSequenceClassification (DistilBERT model)
- ElectraConfig configuration class: FlaxElectraForSequenceClassification (ELECTRA model)
- MBartConfig configuration class: FlaxMBartForSequenceClassification (mBART model)
- RoFormerConfig configuration class: FlaxRoFormerForSequenceClassification (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaForSequenceClassification (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForSequenceClassification (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
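The two kwargs paths can be illustrated with a toy split. This sketch assumes a configuration is represented by a plain dataclass, and the names ToyConfig and split_kwargs are hypothetical rather than library API. It shows only the routing rule: which keys override the config and which are forwarded to the model's __init__.

```python
from dataclasses import dataclass, fields, replace

@dataclass
class ToyConfig:
    # Hypothetical stand-in for a PretrainedConfig subclass.
    hidden_size: int = 768
    output_attentions: bool = False

def split_kwargs(config, kwargs, config_was_provided):
    """Return (config_to_use, model_init_kwargs) following the documented rule."""
    if config_was_provided:
        # Explicit config: every kwarg goes straight to the model's __init__.
        return config, dict(kwargs)
    config_keys = {f.name for f in fields(config)}
    overrides = {k: v for k, v in kwargs.items() if k in config_keys}
    remaining = {k: v for k, v in kwargs.items() if k not in config_keys}
    # Keys matching config attributes override them; the rest go to the model's __init__.
    return replace(config, **overrides), remaining

cfg, model_kwargs = split_kwargs(ToyConfig(), {"output_attentions": True, "my_flag": 1}, config_was_provided=False)
print(cfg.output_attentions, model_kwargs)  # True {'my_flag': 1}
cfg, model_kwargs = split_kwargs(ToyConfig(), {"output_attentions": True}, config_was_provided=True)
print(cfg.output_attentions, model_kwargs)  # False {'output_attentions': True}
```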
Instantiate one of the model classes of the library (with a sequence classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertForSequenceClassification (ALBERT model)
- bart — FlaxBartForSequenceClassification (BART model)
- bert — FlaxBertForSequenceClassification (BERT model)
- big_bird — FlaxBigBirdForSequenceClassification (BigBird model)
- distilbert — FlaxDistilBertForSequenceClassification (DistilBERT model)
- electra — FlaxElectraForSequenceClassification (ELECTRA model)
- mbart — FlaxMBartForSequenceClassification (mBART model)
- roberta — FlaxRobertaForSequenceClassification (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForSequenceClassification (RoFormer model)
- xlm-roberta — FlaxXLMRobertaForSequenceClassification (XLM-RoBERTa model)
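The selection rule above can be sketched as follows: use model_type when it is present, and otherwise match known keys against the name or path. MODEL_MAPPING is a small excerpt of the list above, and select_model_class is a hypothetical helper. Matching the longest key first is an assumption made here so that the most specific key wins. For example, "distilbert" is preferred over "bert", and "roberta-prelayernorm" over "roberta".

```python
# Excerpt of the model_type -> class-name mapping listed above.
MODEL_MAPPING = {
    "bert": "FlaxBertForSequenceClassification",
    "distilbert": "FlaxDistilBertForSequenceClassification",
    "roberta": "FlaxRobertaForSequenceClassification",
    "roberta-prelayernorm": "FlaxRobertaPreLayerNormForSequenceClassification",
    "xlm-roberta": "FlaxXLMRobertaForSequenceClassification",
}

def select_model_class(name_or_path, model_type=None):
    """Hypothetical dispatcher: exact model_type first, substring fallback second."""
    if model_type is not None:
        # Preferred path: the config's model_type is an exact key of the mapping.
        return MODEL_MAPPING[model_type]
    # Fallback: try longer keys first so that e.g. "distilbert" beats "bert".
    for key in sorted(MODEL_MAPPING, key=len, reverse=True):
        if key in name_or_path:
            return MODEL_MAPPING[key]
    raise ValueError(f"Unrecognized model in {name_or_path!r}")

print(select_model_class("./distilbert-finetuned"))  # FlaxDistilBertForSequenceClassification
print(select_model_class("anything", model_type="roberta"))  # FlaxRobertaForSequenceClassification
```

Without the length ordering, iterating the dict in insertion order would match "bert" inside "distilbert" first.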
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForMultipleChoice
This is a generic model class that will be instantiated as one of the model classes of the library (with a multiple choice head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: AlbertForMultipleChoice (ALBERT model)
  - BertConfig configuration class: BertForMultipleChoice (BERT model)
  - BigBirdConfig configuration class: BigBirdForMultipleChoice (BigBird model)
  - CamembertConfig configuration class: CamembertForMultipleChoice (CamemBERT model)
  - CanineConfig configuration class: CanineForMultipleChoice (CANINE model)
  - ConvBertConfig configuration class: ConvBertForMultipleChoice (ConvBERT model)
  - Data2VecTextConfig configuration class: Data2VecTextForMultipleChoice (Data2VecText model)
  - DebertaV2Config configuration class: DebertaV2ForMultipleChoice (DeBERTa-v2 model)
  - DistilBertConfig configuration class: DistilBertForMultipleChoice (DistilBERT model)
  - ElectraConfig configuration class: ElectraForMultipleChoice (ELECTRA model)
  - ErnieConfig configuration class: ErnieForMultipleChoice (ERNIE model)
  - ErnieMConfig configuration class: ErnieMForMultipleChoice (ErnieM model)
  - FNetConfig configuration class: FNetForMultipleChoice (FNet model)
  - FlaubertConfig configuration class: FlaubertForMultipleChoice (FlauBERT model)
  - FunnelConfig configuration class: FunnelForMultipleChoice (Funnel Transformer model)
  - IBertConfig configuration class: IBertForMultipleChoice (I-BERT model)
  - LongformerConfig configuration class: LongformerForMultipleChoice (Longformer model)
  - LukeConfig configuration class: LukeForMultipleChoice (LUKE model)
  - MPNetConfig configuration class: MPNetForMultipleChoice (MPNet model)
  - MegaConfig configuration class: MegaForMultipleChoice (MEGA model)
  - MegatronBertConfig configuration class: MegatronBertForMultipleChoice (Megatron-BERT model)
  - MobileBertConfig configuration class: MobileBertForMultipleChoice (MobileBERT model)
  - MraConfig configuration class: MraForMultipleChoice (MRA model)
  - NezhaConfig configuration class: NezhaForMultipleChoice (Nezha model)
  - NystromformerConfig configuration class: NystromformerForMultipleChoice (Nyströmformer model)
  - QDQBertConfig configuration class: QDQBertForMultipleChoice (QDQBert model)
  - RemBertConfig configuration class: RemBertForMultipleChoice (RemBERT model)
  - RoCBertConfig configuration class: RoCBertForMultipleChoice (RoCBert model)
  - RoFormerConfig configuration class: RoFormerForMultipleChoice (RoFormer model)
  - RobertaConfig configuration class: RobertaForMultipleChoice (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
  - SqueezeBertConfig configuration class: SqueezeBertForMultipleChoice (SqueezeBERT model)
  - XLMConfig configuration class: XLMForMultipleChoice (XLM model)
  - XLMRobertaConfig configuration class: XLMRobertaForMultipleChoice (XLM-RoBERTa model)
  - XLMRobertaXLConfig configuration class: XLMRobertaXLForMultipleChoice (XLM-RoBERTa-XL model)
  - XLNetConfig configuration class: XLNetForMultipleChoice (XLNet model)
  - XmodConfig configuration class: XmodForMultipleChoice (X-MOD model)
  - YosoConfig configuration class: YosoForMultipleChoice (YOSO model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a multiple choice head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
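from_config dispatches on the configuration class rather than on a name, and it only builds a freshly initialized model. The toy version below sketches that dispatch. ToyAlbertConfig, ToyBertConfig, the toy model classes and toy_from_config are all hypothetical stand-ins, not library classes. Lookup by the exact type of the config object is an assumption of this sketch.

```python
import random

class ToyAlbertConfig:
    hidden_size = 4

class ToyBertConfig:
    hidden_size = 8

class _ToyModel:
    def __init__(self, config):
        self.config = config
        # Fresh random weights: building from a config never reads a checkpoint.
        self.weights = [random.random() for _ in range(config.hidden_size)]

class ToyAlbertForMultipleChoice(_ToyModel):
    pass

class ToyBertForMultipleChoice(_ToyModel):
    pass

# Keyed by configuration class, mirroring the "XConfig configuration class: XForMultipleChoice" list.
CONFIG_TO_MODEL = {
    ToyAlbertConfig: ToyAlbertForMultipleChoice,
    ToyBertConfig: ToyBertForMultipleChoice,
}

def toy_from_config(config):
    """Pick the model class from the type of `config` and instantiate it."""
    model_class = CONFIG_TO_MODEL.get(type(config))
    if model_class is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    return model_class(config)

model = toy_from_config(ToyBertConfig())
print(type(model).__name__, len(model.weights))  # ToyBertForMultipleChoice 8
```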
from_pretrained
(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from a saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
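With output_loading_info=True, the extra dictionary reports how the checkpoint's parameter names lined up with the model's. The core of that comparison amounts to two set differences. The sketch below uses plain dictionaries in place of a state_dict. compute_loading_info is a hypothetical helper, and the parameter names are made up for illustration. It covers only the missing and unexpected keys, not the error messages.

```python
def compute_loading_info(expected_keys, state_dict):
    """Compare the parameter names a model expects with those a checkpoint provides."""
    expected = set(expected_keys)
    provided = set(state_dict)
    return {
        # Expected by the model but absent from the checkpoint: these must be newly initialized.
        "missing_keys": sorted(expected - provided),
        # Present in the checkpoint but not used by the model.
        "unexpected_keys": sorted(provided - expected),
    }

# Hypothetical names: a base checkpoint loaded into a model with a new task head.
expected = ["bert.embeddings.word_embeddings.weight", "classifier.weight", "classifier.bias"]
checkpoint = {"bert.embeddings.word_embeddings.weight": [0.1, 0.2], "cls.predictions.bias": [0.0]}
print(compute_loading_info(expected, checkpoint))
# {'missing_keys': ['classifier.bias', 'classifier.weight'], 'unexpected_keys': ['cls.predictions.bias']}
```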
Instantiate one of the model classes of the library (with a multiple choice head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertForMultipleChoice (ALBERT model)
- bert — BertForMultipleChoice (BERT model)
- big_bird — BigBirdForMultipleChoice (BigBird model)
- camembert — CamembertForMultipleChoice (CamemBERT model)
- canine — CanineForMultipleChoice (CANINE model)
- convbert — ConvBertForMultipleChoice (ConvBERT model)
- data2vec-text — Data2VecTextForMultipleChoice (Data2VecText model)
- deberta-v2 — DebertaV2ForMultipleChoice (DeBERTa-v2 model)
- distilbert — DistilBertForMultipleChoice (DistilBERT model)
- electra — ElectraForMultipleChoice (ELECTRA model)
- ernie — ErnieForMultipleChoice (ERNIE model)
- ernie_m — ErnieMForMultipleChoice (ErnieM model)
- flaubert — FlaubertForMultipleChoice (FlauBERT model)
- fnet — FNetForMultipleChoice (FNet model)
- funnel — FunnelForMultipleChoice (Funnel Transformer model)
- ibert — IBertForMultipleChoice (I-BERT model)
- longformer — LongformerForMultipleChoice (Longformer model)
- luke — LukeForMultipleChoice (LUKE model)
- mega — MegaForMultipleChoice (MEGA model)
- megatron-bert — MegatronBertForMultipleChoice (Megatron-BERT model)
- mobilebert — MobileBertForMultipleChoice (MobileBERT model)
- mpnet — MPNetForMultipleChoice (MPNet model)
- mra — MraForMultipleChoice (MRA model)
- nezha — NezhaForMultipleChoice (Nezha model)
- nystromformer — NystromformerForMultipleChoice (Nyströmformer model)
- qdqbert — QDQBertForMultipleChoice (QDQBert model)
- rembert — RemBertForMultipleChoice (RemBERT model)
- roberta — RobertaForMultipleChoice (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForMultipleChoice (RoCBert model)
- roformer — RoFormerForMultipleChoice (RoFormer model)
- squeezebert — SqueezeBertForMultipleChoice (SqueezeBERT model)
- xlm — XLMForMultipleChoice (XLM model)
- xlm-roberta — XLMRobertaForMultipleChoice (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForMultipleChoice (XLM-RoBERTa-XL model)
- xlnet — XLNetForMultipleChoice (XLNet model)
- xmod — XmodForMultipleChoice (X-MOD model)
- yoso — YosoForMultipleChoice (YOSO model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
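Calling eval() matters because dropout only zeroes activations while a module is in training mode. The dependency-free dropout layer below shows that switch. ToyDropout is hypothetical and not torch.nn.Dropout. It uses "inverted" dropout, which rescales the surviving values by 1/(1-p) during training.

```python
import random

class ToyDropout:
    def __init__(self, p=0.5, seed=0):
        self.p = p
        self.training = True  # modules start out in training mode
        self._rng = random.Random(seed)

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, xs):
        if not self.training:
            # Evaluation mode: dropout is a no-op.
            return list(xs)
        keep = 1.0 - self.p
        # Training mode: zero each value with probability p, rescale survivors by 1/keep.
        return [x / keep if self._rng.random() < keep else 0.0 for x in xs]

layer = ToyDropout(p=0.5).eval()  # analogous to the model being in eval mode after loading
print(layer([1.0, 2.0, 3.0]))  # [1.0, 2.0, 3.0]
layer.train()
print(layer([1.0, 2.0, 3.0]))  # each entry is now either 0.0 or doubled
```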
Examples:
>>> from transformers import AutoConfig, AutoModelForMultipleChoice
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForMultipleChoice.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForMultipleChoice
This is a generic model class that will be instantiated as one of the model classes of the library (with a multiple choice head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: TFAlbertForMultipleChoice (ALBERT model)
  - BertConfig configuration class: TFBertForMultipleChoice (BERT model)
  - CamembertConfig configuration class: TFCamembertForMultipleChoice (CamemBERT model)
  - ConvBertConfig configuration class: TFConvBertForMultipleChoice (ConvBERT model)
  - DebertaV2Config configuration class: TFDebertaV2ForMultipleChoice (DeBERTa-v2 model)
  - DistilBertConfig configuration class: TFDistilBertForMultipleChoice (DistilBERT model)
  - ElectraConfig configuration class: TFElectraForMultipleChoice (ELECTRA model)
  - FlaubertConfig configuration class: TFFlaubertForMultipleChoice (FlauBERT model)
  - FunnelConfig configuration class: TFFunnelForMultipleChoice (Funnel Transformer model)
  - LongformerConfig configuration class: TFLongformerForMultipleChoice (Longformer model)
  - MPNetConfig configuration class: TFMPNetForMultipleChoice (MPNet model)
  - MobileBertConfig configuration class: TFMobileBertForMultipleChoice (MobileBERT model)
  - RemBertConfig configuration class: TFRemBertForMultipleChoice (RemBERT model)
  - RoFormerConfig configuration class: TFRoFormerForMultipleChoice (RoFormer model)
  - RobertaConfig configuration class: TFRobertaForMultipleChoice (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
  - XLMConfig configuration class: TFXLMForMultipleChoice (XLM model)
  - XLMRobertaConfig configuration class: TFXLMRobertaForMultipleChoice (XLM-RoBERTa model)
  - XLNetConfig configuration class: TFXLNetForMultipleChoice (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a multiple choice head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
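Because revision accepts any git identifier, a caller who wants reproducible loads may want to tell a pinned commit apart from a branch or tag name, which can move. The sketch below is a heuristic. The helper describe_revision is hypothetical, and treating exactly 40 lowercase hex characters as a full commit id is an assumption of this example. Short hashes, and tags that happen to look like hashes, are not handled.

```python
import re

_FULL_SHA = re.compile(r"[0-9a-f]{40}")

def describe_revision(revision="main"):
    """Classify a revision string; hypothetical helper, not part of transformers."""
    if _FULL_SHA.fullmatch(revision):
        return "commit"  # immutable: always resolves to the same files
    return "branch-or-tag"  # may point somewhere else as the repository is updated

print(describe_revision())  # branch-or-tag
print(describe_revision("0123456789abcdef0123456789abcdef01234567"))  # commit
```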
Instantiate one of the model classes of the library (with a multiple choice head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — TFAlbertForMultipleChoice (ALBERT model)
- bert — TFBertForMultipleChoice (BERT model)
- camembert — TFCamembertForMultipleChoice (CamemBERT model)
- convbert — TFConvBertForMultipleChoice (ConvBERT model)
- deberta-v2 — TFDebertaV2ForMultipleChoice (DeBERTa-v2 model)
- distilbert — TFDistilBertForMultipleChoice (DistilBERT model)
- electra — TFElectraForMultipleChoice (ELECTRA model)
- flaubert — TFFlaubertForMultipleChoice (FlauBERT model)
- funnel — TFFunnelForMultipleChoice (Funnel Transformer model)
- longformer — TFLongformerForMultipleChoice (Longformer model)
- mobilebert — TFMobileBertForMultipleChoice (MobileBERT model)
- mpnet — TFMPNetForMultipleChoice (MPNet model)
- rembert — TFRemBertForMultipleChoice (RemBERT model)
- roberta — TFRobertaForMultipleChoice (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForMultipleChoice (RoFormer model)
- xlm — TFXLMForMultipleChoice (XLM model)
- xlm-roberta — TFXLMRobertaForMultipleChoice (XLM-RoBERTa model)
- xlnet — TFXLNetForMultipleChoice (XLNet model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForMultipleChoice
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForMultipleChoice.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForMultipleChoice
This is a generic model class that will be instantiated as one of the model classes of the library (with a multiple choice head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: FlaxAlbertForMultipleChoice (ALBERT model)
  - BertConfig configuration class: FlaxBertForMultipleChoice (BERT model)
  - BigBirdConfig configuration class: FlaxBigBirdForMultipleChoice (BigBird model)
  - DistilBertConfig configuration class: FlaxDistilBertForMultipleChoice (DistilBERT model)
  - ElectraConfig configuration class: FlaxElectraForMultipleChoice (ELECTRA model)
  - RoFormerConfig configuration class: FlaxRoFormerForMultipleChoice (RoFormer model)
  - RobertaConfig configuration class: FlaxRobertaForMultipleChoice (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
  - XLMRobertaConfig configuration class: FlaxXLMRobertaForMultipleChoice (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a multiple choice head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a Flax model once and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (
Dict[str, str]
, optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g.,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
. The proxies are used on each request. - output_loading_info(
bool
, optional, defaults toFalse
) — Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(
bool
, optional, defaults toFalse
) — Whether or not to only look at local files (e.g., not try downloading the model). - revision (
str
, optional, defaults to"main"
) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, sorevision
can be any identifier allowed by git. - trust_remote_code (
bool
, optional, defaults toFalse
) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set toTrue
for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine. - code_revision (
str
, optional, defaults to"main"
) — The specific revision to use for the code on the Hub, if the code leaves in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, sorevision
can be any identifier allowed by git. - kwargs (additional keyword arguments, optional) —
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
output_attentions=True
). Behaves differently depending on whether aconfig
is provided or automatically loaded:- If a configuration is provided with
config
,**kwargs
will be directly passed to the underlying model’s__init__
method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided,
kwargs
will be first passed to the configuration class initialization function (from_pretrained()). Each key ofkwargs
that corresponds to a configuration attribute will be used to override said attribute with the suppliedkwargs
value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s__init__
function.
- If a configuration is provided with
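The two kwargs rules above amount to a simple partition of the keyword arguments. The sketch below restates that rule in plain Python; `split_from_pretrained_kwargs` and the argument `my_extra_flag` are hypothetical names used only for illustration, not functions or parameters of the library, which performs this split inside the configuration's from_pretrained().

```python
from typing import Any, Dict, Set, Tuple


def split_from_pretrained_kwargs(
    kwargs: Dict[str, Any],
    config_attributes: Set[str],
    config_provided: bool,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split kwargs into (config overrides, model __init__ kwargs).

    Hypothetical helper: it only mirrors the documented rule and is not a
    function that exists in transformers.
    """
    if config_provided:
        # An explicit config is treated as final: everything goes to __init__.
        return {}, dict(kwargs)
    # No config given: keys naming a config attribute override that attribute...
    overrides = {k: v for k, v in kwargs.items() if k in config_attributes}
    # ...and every remaining key is forwarded to the model's __init__.
    init_kwargs = {k: v for k, v in kwargs.items() if k not in config_attributes}
    return overrides, init_kwargs


# "my_extra_flag" is a made-up model argument used only for illustration.
attrs = {"output_attentions", "hidden_dropout_prob"}
overrides, init_kwargs = split_from_pretrained_kwargs(
    {"output_attentions": True, "my_extra_flag": 1}, attrs, config_provided=False
)
print(overrides)    # {'output_attentions': True}
print(init_kwargs)  # {'my_extra_flag': 1}
```

This is why output_attentions=True in the examples ends up on model.config rather than being rejected by the model constructor.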
Instantiate one of the model classes of the library (with a multiple choice head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertForMultipleChoice (ALBERT model)
- bert — FlaxBertForMultipleChoice (BERT model)
- big_bird — FlaxBigBirdForMultipleChoice (BigBird model)
- distilbert — FlaxDistilBertForMultipleChoice (DistilBERT model)
- electra — FlaxElectraForMultipleChoice (ELECTRA model)
- roberta — FlaxRobertaForMultipleChoice (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForMultipleChoice (RoFormer model)
- xlm-roberta — FlaxXLMRobertaForMultipleChoice (XLM-RoBERTa model)
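The selection rule above (model_type first, pattern matching on the name as a fallback) can be sketched with a plain dictionary built from the list. This is a simplified illustration, not the library's code: the real fallback goes through AutoConfig and its full configuration mapping. Trying longer keys first is a design choice made here so that "xlm-roberta" is not captured by "roberta" and "albert" is not captured by "bert"; check the library source before relying on its exact matching order.

```python
from typing import Dict, Optional

# model_type -> Flax class name, copied from the list above.
FLAX_MULTIPLE_CHOICE_MAPPING: Dict[str, str] = {
    "albert": "FlaxAlbertForMultipleChoice",
    "bert": "FlaxBertForMultipleChoice",
    "big_bird": "FlaxBigBirdForMultipleChoice",
    "distilbert": "FlaxDistilBertForMultipleChoice",
    "electra": "FlaxElectraForMultipleChoice",
    "roberta": "FlaxRobertaForMultipleChoice",
    "roberta-prelayernorm": "FlaxRobertaPreLayerNormForMultipleChoice",
    "roformer": "FlaxRoFormerForMultipleChoice",
    "xlm-roberta": "FlaxXLMRobertaForMultipleChoice",
}


def resolve_model_class(model_type: Optional[str], name_or_path: str) -> str:
    # 1) Prefer the model_type stored in the config, when there is one.
    if model_type is not None:
        return FLAX_MULTIPLE_CHOICE_MAPPING[model_type]
    # 2) Otherwise fall back to substring matching on the name or path,
    #    trying longer keys first so the most specific pattern wins.
    for pattern in sorted(FLAX_MULTIPLE_CHOICE_MAPPING, key=len, reverse=True):
        if pattern in name_or_path:
            return FLAX_MULTIPLE_CHOICE_MAPPING[pattern]
    raise ValueError(f"Could not infer a model class from {name_or_path!r}")


print(resolve_model_class("bert", "./my_model_directory/"))        # FlaxBertForMultipleChoice
print(resolve_model_class(None, "FacebookAI/xlm-roberta-base"))   # FlaxXLMRobertaForMultipleChoice
```

Because the fallback only looks at the string, a checkpoint whose name does not mention its architecture can only be resolved through the model_type in its config.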
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForMultipleChoice
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForMultipleChoice.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForNextSentencePrediction
This is a generic model class that will be instantiated as one of the model classes of the library (with a next sentence prediction head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BertConfig configuration class: BertForNextSentencePrediction (BERT model)
  - ErnieConfig configuration class: ErnieForNextSentencePrediction (ERNIE model)
  - FNetConfig configuration class: FNetForNextSentencePrediction (FNet model)
  - MegatronBertConfig configuration class: MegatronBertForNextSentencePrediction (Megatron-BERT model)
  - MobileBertConfig configuration class: MobileBertForNextSentencePrediction (MobileBERT model)
  - NezhaConfig configuration class: NezhaForNextSentencePrediction (Nezha model)
  - QDQBertConfig configuration class: QDQBertForNextSentencePrediction (QDQBert model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a next sentence prediction head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
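from_config chooses the model class purely from the configuration class, and it only builds the architecture. The sketch below models that dispatch with a dictionary keyed by class; BertConfig, MobileBertConfig and the two model classes are local stand-ins defined here (not the library's classes), and `sketch_from_config` is a hypothetical name.

```python
from typing import Dict


class BertConfig:
    """Stand-in for the library's BertConfig (illustration only)."""


class MobileBertConfig:
    """Stand-in for the library's MobileBertConfig (illustration only)."""


class BertForNextSentencePrediction:
    def __init__(self, config):
        # A real model would build randomly initialized layers from config here.
        self.config = config


class MobileBertForNextSentencePrediction:
    def __init__(self, config):
        self.config = config


# Configuration class -> model class, as in the from_config list above.
NSP_MAPPING: Dict[type, type] = {
    BertConfig: BertForNextSentencePrediction,
    MobileBertConfig: MobileBertForNextSentencePrediction,
}


def sketch_from_config(config):
    # The exact configuration class decides which model class is built.
    model_cls = NSP_MAPPING.get(type(config))
    if model_cls is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    # Only the architecture is created; no pretrained weights are loaded.
    return model_cls(config)


model = sketch_from_config(MobileBertConfig())
print(type(model).__name__)  # MobileBertForNextSentencePrediction
```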
from_pretrained( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() is a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id. Since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id. Since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
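The third automatic-loading case for config (a local directory containing config.json) reduces to reading one JSON file and looking at its model_type. The helper below is a hypothetical, minimal sketch of that check using only the standard library; the real loader also resolves Hub ids, caching and many more configuration keys.

```python
import json
import os
import tempfile
from typing import Optional


def local_model_type(directory: str) -> Optional[str]:
    """Return model_type from <directory>/config.json, or None if absent.

    Hypothetical helper illustrating the local-directory case only.
    """
    config_file = os.path.join(directory, "config.json")
    if not os.path.isfile(config_file):
        # No config.json: the configuration cannot be loaded automatically.
        return None
    with open(config_file, encoding="utf-8") as f:
        return json.load(f).get("model_type")


# Usage: write a minimal config.json into a temporary directory and read it back.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"model_type": "bert"}, f)
    print(local_model_type(tmp))  # bert
```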
Instantiate one of the model classes of the library (with a next sentence prediction head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- bert — BertForNextSentencePrediction (BERT model)
- ernie — ErnieForNextSentencePrediction (ERNIE model)
- fnet — FNetForNextSentencePrediction (FNet model)
- megatron-bert — MegatronBertForNextSentencePrediction (Megatron-BERT model)
- mobilebert — MobileBertForNextSentencePrediction (MobileBERT model)
- nezha — NezhaForNextSentencePrediction (Nezha model)
- qdqbert — QDQBertForNextSentencePrediction (QDQBert model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForNextSentencePrediction.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForNextSentencePrediction
This is a generic model class that will be instantiated as one of the model classes of the library (with a next sentence prediction head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BertConfig configuration class: TFBertForNextSentencePrediction (BERT model)
  - MobileBertConfig configuration class: TFMobileBertForNextSentencePrediction (MobileBERT model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a next sentence prediction head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id. Since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id. Since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
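The proxies dictionary may key a proxy by protocol ("http") or by a specific endpoint ("http://hostname"). The sketch below shows how such a dictionary can be resolved for one request URL, most specific key first. The lookup order mirrors the one used by the requests library as we understand it (scheme://host, scheme, all://host, all); treat it as an illustration, not a guarantee of how every download path in the library resolves proxies.

```python
from typing import Dict, Optional
from urllib.parse import urlparse


def select_proxy(url: str, proxies: Dict[str, str]) -> Optional[str]:
    """Pick the proxy for a request URL, most specific key first (sketch)."""
    parts = urlparse(url)
    if parts.hostname is None:
        # No host to match on: fall back to the scheme, then to "all".
        return proxies.get(parts.scheme, proxies.get("all"))
    for key in (
        f"{parts.scheme}://{parts.hostname}",  # endpoint-specific entry
        parts.scheme,                          # protocol-wide entry
        f"all://{parts.hostname}",
        "all",
    ):
        if key in proxies:
            return proxies[key]
    return None


# The example dictionary from the parameter description above.
proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(select_proxy("http://hostname/model.bin", proxies))     # foo.bar:4012
print(select_proxy("http://example.com/model.bin", proxies))  # foo.bar:3128
print(select_proxy("https://example.com/model.bin", proxies)) # None
```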
Instantiate one of the model classes of the library (with a next sentence prediction head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- bert — TFBertForNextSentencePrediction (BERT model)
- mobilebert — TFMobileBertForNextSentencePrediction (MobileBERT model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForNextSentencePrediction
This is a generic model class that will be instantiated as one of the model classes of the library (with a next sentence prediction head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BertConfig configuration class: FlaxBertForNextSentencePrediction (BERT model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a next sentence prediction head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id. Since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id. Since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
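According to the parameter list above, output_loading_info=True makes from_pretrained also return a dictionary of missing keys, unexpected keys and error messages. The sketch below shows what "missing" and "unexpected" mean as set differences between the parameter names a model expects and those a checkpoint provides. The function `loading_info` and the parameter names in the usage example are hypothetical; the real loader may report additional entries.

```python
from typing import Dict, Iterable, List


def loading_info(model_keys: Iterable[str], checkpoint_keys: Iterable[str]) -> Dict[str, List[str]]:
    """Compare parameter names the way a loading-info dictionary reports them (sketch)."""
    expected, found = set(model_keys), set(checkpoint_keys)
    return {
        # Expected by the model but absent from the checkpoint (typically newly initialized).
        "missing_keys": sorted(expected - found),
        # Present in the checkpoint but not used by the model.
        "unexpected_keys": sorted(found - expected),
        # Always empty in this sketch; the real loader collects load errors here.
        "error_msgs": [],
    }


# Hypothetical parameter names: a checkpoint without a next sentence prediction head.
info = loading_info(
    model_keys=["bert.embeddings.weight", "cls.seq_relationship.weight"],
    checkpoint_keys=["bert.embeddings.weight", "cls.predictions.bias"],
)
print(info["missing_keys"])     # ['cls.seq_relationship.weight']
print(info["unexpected_keys"])  # ['cls.predictions.bias']
```

A non-empty missing_keys list for the head is the usual sign that the head still needs fine-tuning before its outputs are meaningful.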
Instantiate one of the model classes of the library (with a next sentence prediction head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- bert — FlaxBertForNextSentencePrediction (BERT model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForNextSentencePrediction
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForNextSentencePrediction.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForTokenClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a token classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: AlbertForTokenClassification (ALBERT model)
  - BertConfig configuration class: BertForTokenClassification (BERT model)
  - BigBirdConfig configuration class: BigBirdForTokenClassification (BigBird model)
  - BioGptConfig configuration class: BioGptForTokenClassification (BioGpt model)
  - BloomConfig configuration class: BloomForTokenClassification (BLOOM model)
  - BrosConfig configuration class: BrosForTokenClassification (BROS model)
  - CamembertConfig configuration class: CamembertForTokenClassification (CamemBERT model)
  - CanineConfig configuration class: CanineForTokenClassification (CANINE model)
  - ConvBertConfig configuration class: ConvBertForTokenClassification (ConvBERT model)
  - Data2VecTextConfig configuration class: Data2VecTextForTokenClassification (Data2VecText model)
  - DebertaConfig configuration class: DebertaForTokenClassification (DeBERTa model)
  - DebertaV2Config configuration class: DebertaV2ForTokenClassification (DeBERTa-v2 model)
  - DiffLlamaConfig configuration class: DiffLlamaForTokenClassification (DiffLlama model)
  - DistilBertConfig configuration class: DistilBertForTokenClassification (DistilBERT model)
  - ElectraConfig configuration class: ElectraForTokenClassification (ELECTRA model)
  - ErnieConfig configuration class: ErnieForTokenClassification (ERNIE model)
  - ErnieMConfig configuration class: ErnieMForTokenClassification (ErnieM model)
  - EsmConfig configuration class: EsmForTokenClassification (ESM model)
  - FNetConfig configuration class: FNetForTokenClassification (FNet model)
  - FalconConfig configuration class: FalconForTokenClassification (Falcon model)
  - FlaubertConfig configuration class: FlaubertForTokenClassification (FlauBERT model)
  - FunnelConfig configuration class: FunnelForTokenClassification (Funnel Transformer model)
  - GPT2Config configuration class: GPT2ForTokenClassification (OpenAI GPT-2 model)
  - GPTBigCodeConfig configuration class: GPTBigCodeForTokenClassification (GPTBigCode model)
  - GPTNeoConfig configuration class: GPTNeoForTokenClassification (GPT Neo model)
  - GPTNeoXConfig configuration class: GPTNeoXForTokenClassification (GPT NeoX model)
  - Gemma2Config configuration class: Gemma2ForTokenClassification (Gemma2 model)
  - GemmaConfig configuration class: GemmaForTokenClassification (Gemma model)
  - GlmConfig configuration class: GlmForTokenClassification (GLM model)
  - HeliumConfig configuration class: HeliumForTokenClassification (Helium model)
  - IBertConfig configuration class: IBertForTokenClassification (I-BERT model)
  - LayoutLMConfig configuration class: LayoutLMForTokenClassification (LayoutLM model)
  - LayoutLMv2Config configuration class: LayoutLMv2ForTokenClassification (LayoutLMv2 model)
  - LayoutLMv3Config configuration class: LayoutLMv3ForTokenClassification (LayoutLMv3 model)
  - LiltConfig configuration class: LiltForTokenClassification (LiLT model)
  - LlamaConfig configuration class: LlamaForTokenClassification (LLaMA model)
  - LongformerConfig configuration class: LongformerForTokenClassification (Longformer model)
  - LukeConfig configuration class: LukeForTokenClassification (LUKE model)
  - MPNetConfig configuration class: MPNetForTokenClassification (MPNet model)
  - MT5Config configuration class: MT5ForTokenClassification (MT5 model)
  - MarkupLMConfig configuration class: MarkupLMForTokenClassification (MarkupLM model)
  - MegaConfig configuration class: MegaForTokenClassification (MEGA model)
  - MegatronBertConfig configuration class: MegatronBertForTokenClassification (Megatron-BERT model)
  - MistralConfig configuration class: MistralForTokenClassification (Mistral model)
  - MixtralConfig configuration class: MixtralForTokenClassification (Mixtral model)
  - MobileBertConfig configuration class: MobileBertForTokenClassification (MobileBERT model)
  - ModernBertConfig configuration class: ModernBertForTokenClassification (ModernBERT model)
  - MptConfig configuration class: MptForTokenClassification (MPT model)
  - MraConfig configuration class: MraForTokenClassification (MRA model)
  - NemotronConfig configuration class: NemotronForTokenClassification (Nemotron model)
  - NezhaConfig configuration class: NezhaForTokenClassification (Nezha model)
  - NystromformerConfig configuration class: NystromformerForTokenClassification (Nyströmformer model)
  - PersimmonConfig configuration class: PersimmonForTokenClassification (Persimmon model)
  - Phi3Config configuration class: Phi3ForTokenClassification (Phi3 model)
  - PhiConfig configuration class: PhiForTokenClassification (Phi model)
  - QDQBertConfig configuration class: QDQBertForTokenClassification (QDQBert model)
  - Qwen2Config configuration class: Qwen2ForTokenClassification (Qwen2 model)
  - Qwen2MoeConfig configuration class: Qwen2MoeForTokenClassification (Qwen2MoE model)
  - RemBertConfig configuration class: RemBertForTokenClassification (RemBERT model)
  - RoCBertConfig configuration class: RoCBertForTokenClassification (RoCBert model)
  - RoFormerConfig configuration class: RoFormerForTokenClassification (RoFormer model)
  - RobertaConfig configuration class: RobertaForTokenClassification (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
  - SqueezeBertConfig configuration class: SqueezeBertForTokenClassification (SqueezeBERT model)
  - StableLmConfig configuration class: StableLmForTokenClassification (StableLm model)
  - Starcoder2Config configuration class: Starcoder2ForTokenClassification (Starcoder2 model)
  - T5Config configuration class: T5ForTokenClassification (T5 model)
  - UMT5Config configuration class: UMT5ForTokenClassification (UMT5 model)
  - XLMConfig configuration class: XLMForTokenClassification (XLM model)
  - XLMRobertaConfig configuration class: XLMRobertaForTokenClassification (XLM-RoBERTa model)
  - XLMRobertaXLConfig configuration class: XLMRobertaXLForTokenClassification (XLM-RoBERTa-XL model)
  - XLNetConfig configuration class: XLNetForTokenClassification (XLNet model)
  - XmodConfig configuration class: XmodForTokenClassification (X-MOD model)
  - YosoConfig configuration class: YosoForTokenClassification (YOSO model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a token classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
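The configuration-class dispatch described above can be pictured with a small, self-contained sketch. It is not the transformers implementation: ToyConfig, ToyForTokenClassification and ToyAutoModel are hypothetical names, and a plain dict stands in for the library's internal mapping. It illustrates two documented behaviors: the model class is chosen from the type of the config object, and from_config builds a fresh model without reading any pretrained weights.

```python
# Hypothetical sketch of config-class dispatch; not the transformers implementation.


class ToyConfig:
    """Stand-in for a configuration class such as BertConfig."""

    def __init__(self, num_labels=2):
        self.num_labels = num_labels


class ToyForTokenClassification:
    """Stand-in for a model class with a token classification head."""

    def __init__(self, config):
        self.config = config
        # from_config only builds the architecture; no pretrained weights are read.
        self.pretrained_weights_loaded = False


class ToyAutoModel:
    # Maps a configuration class to the model class it should produce.
    _mapping = {}

    def __init__(self):
        # Mirrors the documented rule that Auto classes cannot be built with __init__().
        raise EnvironmentError("Use ToyAutoModel.from_config(config) instead.")

    @classmethod
    def register(cls, config_class, model_class):
        cls._mapping[config_class] = model_class

    @classmethod
    def from_config(cls, config, **kwargs):
        model_class = cls._mapping.get(type(config))
        if model_class is None:
            raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
        return model_class(config, **kwargs)


ToyAutoModel.register(ToyConfig, ToyForTokenClassification)
model = ToyAutoModel.from_config(ToyConfig(num_labels=5))
print(type(model).__name__, model.config.num_labels, model.pretrained_weights_loaded)
```

The same registry idea is what the register() calls in the extension section rely on: adding an entry is enough for from_config to find the new model class.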
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, consider whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
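The two kwargs rules above can be sketched in plain Python. This is an illustrative model only: ToyConfig and resolve_kwargs are hypothetical names, and using hasattr() to decide whether a key is a configuration attribute is an assumption about the rule, not a copy of the library's code.

```python
# Hypothetical sketch of the two **kwargs rules; not the transformers implementation.


class ToyConfig:
    """Stand-in for a configuration loaded from a checkpoint."""

    def __init__(self, hidden_size=768, output_attentions=False):
        self.hidden_size = hidden_size
        self.output_attentions = output_attentions


def resolve_kwargs(config=None, **kwargs):
    """Return (config, kwargs that would be forwarded to the model's __init__)."""
    if config is not None:
        # Rule 1: an explicit config is used as-is; all kwargs go to __init__.
        return config, dict(kwargs)
    # Rule 2: no config given, so one is "loaded" and kwargs are split.
    config = ToyConfig()
    model_kwargs = {}
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)  # overrides the config attribute
        else:
            model_kwargs[key] = value  # unknown to the config: forwarded to __init__
    return config, model_kwargs


config, model_kwargs = resolve_kwargs(output_attentions=True, extra_flag=3)
print(config.output_attentions, model_kwargs)  # True {'extra_flag': 3}

explicit, forwarded = resolve_kwargs(config=ToyConfig(), output_attentions=True)
print(explicit.output_attentions, forwarded)  # False {'output_attentions': True}
```

The second call shows why passing both config and configuration-style kwargs can surprise: with an explicit config, output_attentions is not applied to the config at all.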
Instantiate one of the model classes of the library (with a token classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertForTokenClassification (ALBERT model)
- bert — BertForTokenClassification (BERT model)
- big_bird — BigBirdForTokenClassification (BigBird model)
- biogpt — BioGptForTokenClassification (BioGpt model)
- bloom — BloomForTokenClassification (BLOOM model)
- bros — BrosForTokenClassification (BROS model)
- camembert — CamembertForTokenClassification (CamemBERT model)
- canine — CanineForTokenClassification (CANINE model)
- convbert — ConvBertForTokenClassification (ConvBERT model)
- data2vec-text — Data2VecTextForTokenClassification (Data2VecText model)
- deberta — DebertaForTokenClassification (DeBERTa model)
- deberta-v2 — DebertaV2ForTokenClassification (DeBERTa-v2 model)
- diffllama — DiffLlamaForTokenClassification (DiffLlama model)
- distilbert — DistilBertForTokenClassification (DistilBERT model)
- electra — ElectraForTokenClassification (ELECTRA model)
- ernie — ErnieForTokenClassification (ERNIE model)
- ernie_m — ErnieMForTokenClassification (ErnieM model)
- esm — EsmForTokenClassification (ESM model)
- falcon — FalconForTokenClassification (Falcon model)
- flaubert — FlaubertForTokenClassification (FlauBERT model)
- fnet — FNetForTokenClassification (FNet model)
- funnel — FunnelForTokenClassification (Funnel Transformer model)
- gemma — GemmaForTokenClassification (Gemma model)
- gemma2 — Gemma2ForTokenClassification (Gemma2 model)
- glm — GlmForTokenClassification (GLM model)
- gpt-sw3 — GPT2ForTokenClassification (GPT-Sw3 model)
- gpt2 — GPT2ForTokenClassification (OpenAI GPT-2 model)
- gpt_bigcode — GPTBigCodeForTokenClassification (GPTBigCode model)
- gpt_neo — GPTNeoForTokenClassification (GPT Neo model)
- gpt_neox — GPTNeoXForTokenClassification (GPT NeoX model)
- helium — HeliumForTokenClassification (Helium model)
- ibert — IBertForTokenClassification (I-BERT model)
- layoutlm — LayoutLMForTokenClassification (LayoutLM model)
- layoutlmv2 — LayoutLMv2ForTokenClassification (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3ForTokenClassification (LayoutLMv3 model)
- lilt — LiltForTokenClassification (LiLT model)
- llama — LlamaForTokenClassification (LLaMA model)
- longformer — LongformerForTokenClassification (Longformer model)
- luke — LukeForTokenClassification (LUKE model)
- markuplm — MarkupLMForTokenClassification (MarkupLM model)
- mega — MegaForTokenClassification (MEGA model)
- megatron-bert — MegatronBertForTokenClassification (Megatron-BERT model)
- mistral — MistralForTokenClassification (Mistral model)
- mixtral — MixtralForTokenClassification (Mixtral model)
- mobilebert — MobileBertForTokenClassification (MobileBERT model)
- modernbert — ModernBertForTokenClassification (ModernBERT model)
- mpnet — MPNetForTokenClassification (MPNet model)
- mpt — MptForTokenClassification (MPT model)
- mra — MraForTokenClassification (MRA model)
- mt5 — MT5ForTokenClassification (MT5 model)
- nemotron — NemotronForTokenClassification (Nemotron model)
- nezha — NezhaForTokenClassification (Nezha model)
- nystromformer — NystromformerForTokenClassification (Nyströmformer model)
- persimmon — PersimmonForTokenClassification (Persimmon model)
- phi — PhiForTokenClassification (Phi model)
- phi3 — Phi3ForTokenClassification (Phi3 model)
- qdqbert — QDQBertForTokenClassification (QDQBert model)
- qwen2 — Qwen2ForTokenClassification (Qwen2 model)
- qwen2_moe — Qwen2MoeForTokenClassification (Qwen2MoE model)
- rembert — RemBertForTokenClassification (RemBERT model)
- roberta — RobertaForTokenClassification (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForTokenClassification (RoCBert model)
- roformer — RoFormerForTokenClassification (RoFormer model)
- squeezebert — SqueezeBertForTokenClassification (SqueezeBERT model)
- stablelm — StableLmForTokenClassification (StableLm model)
- starcoder2 — Starcoder2ForTokenClassification (Starcoder2 model)
- t5 — T5ForTokenClassification (T5 model)
- umt5 — UMT5ForTokenClassification (UMT5 model)
- xlm — XLMForTokenClassification (XLM model)
- xlm-roberta — XLMRobertaForTokenClassification (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForTokenClassification (XLM-RoBERTa-XL model)
- xlnet — XLNetForTokenClassification (XLNet model)
- xmod — XmodForTokenClassification (X-MOD model)
- yoso — YosoForTokenClassification (YOSO model)
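The selection rule above (use model_type when the config has one, otherwise match a substring of pretrained_model_name_or_path) can be sketched as follows. The four-entry mapping and the select_model_class helper are hypothetical, and trying longer patterns first is an assumption made here so that a name such as "roberta-large" is not captured by the shorter "bert" key, which is itself a substring of "roberta".

```python
# Hypothetical sketch of model_type lookup with a name-matching fallback.
# The mapping is a tiny excerpt; class names are returned as strings.
MODEL_MAPPING = {
    "bert": "BertForTokenClassification",
    "roberta": "RobertaForTokenClassification",
    "xlm-roberta": "XLMRobertaForTokenClassification",
    "xlm-roberta-xl": "XLMRobertaXLForTokenClassification",
}


def select_model_class(config_dict, name_or_path):
    model_type = config_dict.get("model_type")
    if model_type is not None:
        # Primary rule: trust the model_type stored in the configuration.
        return MODEL_MAPPING[model_type]
    # Fallback: substring match on the name or path, longest pattern first,
    # so "xlm-roberta" wins over "roberta", and "roberta" wins over "bert".
    for pattern in sorted(MODEL_MAPPING, key=len, reverse=True):
        if pattern in str(name_or_path):
            return MODEL_MAPPING[pattern]
    raise ValueError(f"Cannot infer a model type from {name_or_path!r}")


print(select_model_class({"model_type": "bert"}, "./anything"))
print(select_model_class({}, "my-org/xlm-roberta-base-ner"))
print(select_model_class({}, "roberta-large"))
```

Because the fallback depends only on the string, a local directory with a misleading name can select the wrong class; keeping model_type in config.json avoids relying on it.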
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForTokenClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForTokenClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForTokenClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a token classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__()
(throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: TFAlbertForTokenClassification (ALBERT model)
  - BertConfig configuration class: TFBertForTokenClassification (BERT model)
  - CamembertConfig configuration class: TFCamembertForTokenClassification (CamemBERT model)
  - ConvBertConfig configuration class: TFConvBertForTokenClassification (ConvBERT model)
  - DebertaConfig configuration class: TFDebertaForTokenClassification (DeBERTa model)
  - DebertaV2Config configuration class: TFDebertaV2ForTokenClassification (DeBERTa-v2 model)
  - DistilBertConfig configuration class: TFDistilBertForTokenClassification (DistilBERT model)
  - ElectraConfig configuration class: TFElectraForTokenClassification (ELECTRA model)
  - EsmConfig configuration class: TFEsmForTokenClassification (ESM model)
  - FlaubertConfig configuration class: TFFlaubertForTokenClassification (FlauBERT model)
  - FunnelConfig configuration class: TFFunnelForTokenClassification (Funnel Transformer model)
  - LayoutLMConfig configuration class: TFLayoutLMForTokenClassification (LayoutLM model)
  - LayoutLMv3Config configuration class: TFLayoutLMv3ForTokenClassification (LayoutLMv3 model)
  - LongformerConfig configuration class: TFLongformerForTokenClassification (Longformer model)
  - MPNetConfig configuration class: TFMPNetForTokenClassification (MPNet model)
  - MobileBertConfig configuration class: TFMobileBertForTokenClassification (MobileBERT model)
  - RemBertConfig configuration class: TFRemBertForTokenClassification (RemBERT model)
  - RoFormerConfig configuration class: TFRoFormerForTokenClassification (RoFormer model)
  - RobertaConfig configuration class: TFRobertaForTokenClassification (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
  - XLMConfig configuration class: TFXLMForTokenClassification (XLM model)
  - XLMRobertaConfig configuration class: TFXLMRobertaForTokenClassification (XLM-RoBERTa model)
  - XLNetConfig configuration class: TFXLNetForTokenClassification (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a token classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
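The proxies dictionary above can key a proxy either by protocol ("http") or by a specific endpoint ("http://hostname"). The sketch below resolves a URL against such a dictionary, most specific key first. The lookup order mirrors the select_proxy helper in the requests library; that your installed Hub client applies the same order is an assumption, and pick_proxy is a hypothetical name.

```python
from urllib.parse import urlparse


def pick_proxy(url, proxies):
    """Return the proxy for url, trying scheme://host, scheme, all://host, all."""
    parts = urlparse(url)
    if parts.hostname is None:
        # No host to match on: fall back to the scheme key, then "all".
        return proxies.get(parts.scheme, proxies.get("all"))
    candidates = [
        f"{parts.scheme}://{parts.hostname}",  # most specific: one endpoint
        parts.scheme,                          # every URL with this protocol
        f"all://{parts.hostname}",             # this host, any protocol
        "all",                                 # catch-all
    ]
    for key in candidates:
        if key in proxies:
            return proxies[key]
    return None  # no proxy configured for this URL


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(pick_proxy("http://hostname/path", proxies))    # foo.bar:4012
print(pick_proxy("http://other.host/path", proxies))  # foo.bar:3128
print(pick_proxy("https://hostname/path", proxies))   # None
```

Note that the example dictionary from the parameter description only covers plain http; requests to https URLs would go out without a proxy unless an "https" (or "all") key is added.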
Instantiate one of the model classes of the library (with a token classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — TFAlbertForTokenClassification (ALBERT model)
- bert — TFBertForTokenClassification (BERT model)
- camembert — TFCamembertForTokenClassification (CamemBERT model)
- convbert — TFConvBertForTokenClassification (ConvBERT model)
- deberta — TFDebertaForTokenClassification (DeBERTa model)
- deberta-v2 — TFDebertaV2ForTokenClassification (DeBERTa-v2 model)
- distilbert — TFDistilBertForTokenClassification (DistilBERT model)
- electra — TFElectraForTokenClassification (ELECTRA model)
- esm — TFEsmForTokenClassification (ESM model)
- flaubert — TFFlaubertForTokenClassification (FlauBERT model)
- funnel — TFFunnelForTokenClassification (Funnel Transformer model)
- layoutlm — TFLayoutLMForTokenClassification (LayoutLM model)
- layoutlmv3 — TFLayoutLMv3ForTokenClassification (LayoutLMv3 model)
- longformer — TFLongformerForTokenClassification (Longformer model)
- mobilebert — TFMobileBertForTokenClassification (MobileBERT model)
- mpnet — TFMPNetForTokenClassification (MPNet model)
- rembert — TFRemBertForTokenClassification (RemBERT model)
- roberta — TFRobertaForTokenClassification (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForTokenClassification (RoFormer model)
- xlm — TFXLMForTokenClassification (XLM model)
- xlm-roberta — TFXLMRobertaForTokenClassification (XLM-RoBERTa model)
- xlnet — TFXLNetForTokenClassification (XLNet model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForTokenClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForTokenClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForTokenClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a token classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__()
(throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: FlaxAlbertForTokenClassification (ALBERT model)
  - BertConfig configuration class: FlaxBertForTokenClassification (BERT model)
  - BigBirdConfig configuration class: FlaxBigBirdForTokenClassification (BigBird model)
  - DistilBertConfig configuration class: FlaxDistilBertForTokenClassification (DistilBERT model)
  - ElectraConfig configuration class: FlaxElectraForTokenClassification (ELECTRA model)
  - RoFormerConfig configuration class: FlaxRoFormerForTokenClassification (RoFormer model)
  - RobertaConfig configuration class: FlaxRobertaForTokenClassification (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
  - XLMRobertaConfig configuration class: FlaxXLMRobertaForTokenClassification (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a token classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
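What output_loading_info reports as missing and unexpected keys can be illustrated with set differences between the parameter names the model expects and those found in the checkpoint. This is a simplified sketch with hypothetical parameter names and a hypothetical loading_info helper; the dictionary keys are chosen for illustration, and the real loader also handles details such as base-model prefixes, so treat it as an intuition aid only. A freshly added token classification head typically shows up as missing, while a pretraining-only head shows up as unexpected.

```python
def loading_info(expected_keys, checkpoint_keys):
    """Compare expected parameter names with those present in a checkpoint."""
    expected, loaded = set(expected_keys), set(checkpoint_keys)
    return {
        # In the model but not the checkpoint: left randomly initialized.
        "missing_keys": sorted(expected - loaded),
        # In the checkpoint but not the model: ignored when loading.
        "unexpected_keys": sorted(loaded - expected),
        "error_msgs": [],  # shape mismatches etc. are not modeled here
    }


# Hypothetical names: an encoder plus a new token classification head.
model_keys = ["bert.embeddings.weight", "bert.encoder.weight", "classifier.weight", "classifier.bias"]
# Hypothetical checkpoint: same encoder plus a pretraining-only head.
ckpt_keys = ["bert.embeddings.weight", "bert.encoder.weight", "cls.predictions.bias"]

info = loading_info(model_keys, ckpt_keys)
print(info["missing_keys"])     # ['classifier.bias', 'classifier.weight']
print(info["unexpected_keys"])  # ['cls.predictions.bias']
```

Missing classifier weights are expected when fine-tuning starts from a base checkpoint; missing encoder weights, by contrast, usually signal a mismatched checkpoint.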
Instantiate one of the model classes of the library (with a token classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertForTokenClassification (ALBERT model)
- bert — FlaxBertForTokenClassification (BERT model)
- big_bird — FlaxBigBirdForTokenClassification (BigBird model)
- distilbert — FlaxDistilBertForTokenClassification (DistilBERT model)
- electra — FlaxElectraForTokenClassification (ELECTRA model)
- roberta — FlaxRobertaForTokenClassification (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForTokenClassification (RoFormer model)
- xlm-roberta — FlaxXLMRobertaForTokenClassification (XLM-RoBERTa model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForTokenClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForTokenClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__()
(throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: AlbertForQuestionAnswering (ALBERT model)
  - BartConfig configuration class: BartForQuestionAnswering (BART model)
  - BertConfig configuration class: BertForQuestionAnswering (BERT model)
  - BigBirdConfig configuration class: BigBirdForQuestionAnswering (BigBird model)
  - BigBirdPegasusConfig configuration class: BigBirdPegasusForQuestionAnswering (BigBird-Pegasus model)
  - BloomConfig configuration class: BloomForQuestionAnswering (BLOOM model)
  - CamembertConfig configuration class: CamembertForQuestionAnswering (CamemBERT model)
  - CanineConfig configuration class: CanineForQuestionAnswering (CANINE model)
  - ConvBertConfig configuration class: ConvBertForQuestionAnswering (ConvBERT model)
  - Data2VecTextConfig configuration class: Data2VecTextForQuestionAnswering (Data2VecText model)
  - DebertaConfig configuration class: DebertaForQuestionAnswering (DeBERTa model)
  - DebertaV2Config configuration class: DebertaV2ForQuestionAnswering (DeBERTa-v2 model)
  - DiffLlamaConfig configuration class: DiffLlamaForQuestionAnswering (DiffLlama model)
  - DistilBertConfig configuration class: DistilBertForQuestionAnswering (DistilBERT model)
  - ElectraConfig configuration class: ElectraForQuestionAnswering (ELECTRA model)
  - ErnieConfig configuration class: ErnieForQuestionAnswering (ERNIE model)
  - ErnieMConfig configuration class: ErnieMForQuestionAnswering (ErnieM model)
  - FNetConfig configuration class: FNetForQuestionAnswering (FNet model)
  - FalconConfig configuration class: FalconForQuestionAnswering (Falcon model)
  - FlaubertConfig configuration class: FlaubertForQuestionAnsweringSimple (FlauBERT model)
  - FunnelConfig configuration class: FunnelForQuestionAnswering (Funnel Transformer model)
  - GPT2Config configuration class: GPT2ForQuestionAnswering (OpenAI GPT-2 model)
  - GPTJConfig configuration class: GPTJForQuestionAnswering (GPT-J model)
  - GPTNeoConfig configuration class: GPTNeoForQuestionAnswering (GPT Neo model)
  - GPTNeoXConfig configuration class: GPTNeoXForQuestionAnswering (GPT NeoX model)
  - IBertConfig configuration class: IBertForQuestionAnswering (I-BERT model)
  - LEDConfig configuration class: LEDForQuestionAnswering (LED model)
  - LayoutLMv2Config configuration class: LayoutLMv2ForQuestionAnswering (LayoutLMv2 model)
  - LayoutLMv3Config configuration class: LayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
  - LiltConfig configuration class: LiltForQuestionAnswering (LiLT model)
  - LlamaConfig configuration class: LlamaForQuestionAnswering (LLaMA model)
  - LongformerConfig configuration class: LongformerForQuestionAnswering (Longformer model)
  - LukeConfig configuration class: LukeForQuestionAnswering (LUKE model)
  - LxmertConfig configuration class: LxmertForQuestionAnswering (LXMERT model)
  - MBartConfig configuration class: MBartForQuestionAnswering (mBART model)
  - MPNetConfig configuration class: MPNetForQuestionAnswering (MPNet model)
  - MT5Config configuration class: MT5ForQuestionAnswering (MT5 model)
  - MarkupLMConfig configuration class: MarkupLMForQuestionAnswering (MarkupLM model)
  - MegaConfig configuration class: MegaForQuestionAnswering (MEGA model)
  - MegatronBertConfig configuration class: MegatronBertForQuestionAnswering (Megatron-BERT model)
  - MistralConfig configuration class: MistralForQuestionAnswering (Mistral model)
  - MixtralConfig configuration class: MixtralForQuestionAnswering (Mixtral model)
  - MobileBertConfig configuration class: MobileBertForQuestionAnswering (MobileBERT model)
  - MptConfig configuration class: MptForQuestionAnswering (MPT model)
  - MraConfig configuration class: MraForQuestionAnswering (MRA model)
  - MvpConfig configuration class: MvpForQuestionAnswering (MVP model)
  - NemotronConfig configuration class: NemotronForQuestionAnswering (Nemotron model)
  - NezhaConfig configuration class: NezhaForQuestionAnswering (Nezha model)
  - NystromformerConfig configuration class: NystromformerForQuestionAnswering (Nyströmformer model)
  - OPTConfig configuration class: OPTForQuestionAnswering (OPT model)
  - QDQBertConfig configuration class: QDQBertForQuestionAnswering (QDQBert model)
  - Qwen2Config configuration class: Qwen2ForQuestionAnswering (Qwen2 model)
  - Qwen2MoeConfig configuration class: Qwen2MoeForQuestionAnswering (Qwen2MoE model)
  - ReformerConfig configuration class: ReformerForQuestionAnswering (Reformer model)
  - RemBertConfig configuration class: RemBertForQuestionAnswering (RemBERT model)
  - RoCBertConfig configuration class: RoCBertForQuestionAnswering (RoCBert model)
  - RoFormerConfig configuration class: RoFormerForQuestionAnswering (RoFormer model)
  - RobertaConfig configuration class: RobertaForQuestionAnswering (RoBERTa model)
  - RobertaPreLayerNormConfig
configuration class:RobertaPreLayerNormForQuestionAnswering
(RoBERTa-PreLayerNorm model)SplinterConfig
configuration class:SplinterForQuestionAnswering
(Splinter model)SqueezeBertConfig
configuration class:SqueezeBertForQuestionAnswering
(SqueezeBERT model)T5Config
configuration class:T5ForQuestionAnswering
(T5 model)UMT5Config
configuration class:UMT5ForQuestionAnswering
(UMT5 model)XLMConfig
configuration class:XLMForQuestionAnsweringSimple
(XLM model)XLMRobertaConfig
configuration class:XLMRobertaForQuestionAnswering
(XLM-RoBERTa model)XLMRobertaXLConfig
configuration class:XLMRobertaXLForQuestionAnswering
(XLM-RoBERTa-XL model)XLNetConfig
configuration class:XLNetForQuestionAnsweringSimple
(XLNet model)XmodConfig
configuration class:XmodForQuestionAnswering
(X-MOD model)YosoConfig
configuration class:YosoForQuestionAnswering
(YOSO model)
- attn_implementation (
str
, optional) — The attention implementation to use in the model (if relevant). Can be any of"eager"
(manual implementation of the attention),"sdpa"
(usingF.scaled_dot_product_attention
), or"flash_attention_2"
(using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual"eager"
implementation.
Instantiates one of the model classes of the library (with a question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
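The from_config selection above amounts to a lookup keyed on the type of the config object. The sketch below is a toy model of that idea, not the library's code: ToyBertConfig, ToyRobertaConfig, the ToyXxxForQA classes, and the registry dict are all hypothetical stand-ins. It also shows why from_config yields untrained weights: nothing is read from a checkpoint.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for library configuration classes.
@dataclass
class ToyBertConfig:
    hidden_size: int = 8


@dataclass
class ToyRobertaConfig:
    hidden_size: int = 8


class ToyModel:
    def __init__(self, config):
        # A real model would build its layers here with freshly initialized
        # weights; from_config never reads a checkpoint.
        self.config = config


class ToyBertForQA(ToyModel):
    pass


class ToyRobertaForQA(ToyModel):
    pass


# Registry keyed on the config *class*: "configuration class -> model class".
_QA_MAPPING = {
    ToyBertConfig: ToyBertForQA,
    ToyRobertaConfig: ToyRobertaForQA,
}


def toy_from_config(config):
    """Pick the model class from type(config) and instantiate it."""
    model_cls = _QA_MAPPING.get(type(config))
    if model_cls is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    return model_cls(config)


model = toy_from_config(ToyRobertaConfig(hidden_size=16))
print(type(model).__name__, model.config.hidden_size)  # ToyRobertaForQA 16
```

Keying on the exact config class keeps the dispatch unambiguous; an unknown config fails loudly instead of silently picking the wrong head.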
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
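The second kwargs rule (no config passed) splits the keyword arguments in two: keys that name a configuration attribute override it, and everything else is forwarded to the model's __init__. A simplified sketch of that rule follows; ToyConfig, split_kwargs, and some_init_arg are hypothetical names, and the library's own bookkeeping may differ in detail.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class ToyConfig:
    # Hypothetical configuration attributes.
    output_attentions: bool = False
    hidden_dropout_prob: float = 0.1


def split_kwargs(config, **kwargs):
    """Return (updated_config, remaining_kwargs).

    Keys naming a config attribute override that attribute; all other keys
    are left over for the model's __init__.
    """
    config_keys = {f.name for f in fields(config)}
    overrides = {k: v for k, v in kwargs.items() if k in config_keys}
    remaining = {k: v for k, v in kwargs.items() if k not in config_keys}
    # replace() builds a new config; the loaded one is left untouched.
    return replace(config, **overrides), remaining


loaded = ToyConfig()  # stands in for the config loaded from the checkpoint
config, model_kwargs = split_kwargs(loaded, output_attentions=True, some_init_arg=3)
print(config.output_attentions, model_kwargs)  # True {'some_init_arg': 3}
```

When a config is passed explicitly, this split is skipped and all kwargs go straight to __init__, which is why the docs assume the config was already updated.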
Instantiate one of the model classes of the library (with a question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — AlbertForQuestionAnswering (ALBERT model)
- bart — BartForQuestionAnswering (BART model)
- bert — BertForQuestionAnswering (BERT model)
- big_bird — BigBirdForQuestionAnswering (BigBird model)
- bigbird_pegasus — BigBirdPegasusForQuestionAnswering (BigBird-Pegasus model)
- bloom — BloomForQuestionAnswering (BLOOM model)
- camembert — CamembertForQuestionAnswering (CamemBERT model)
- canine — CanineForQuestionAnswering (CANINE model)
- convbert — ConvBertForQuestionAnswering (ConvBERT model)
- data2vec-text — Data2VecTextForQuestionAnswering (Data2VecText model)
- deberta — DebertaForQuestionAnswering (DeBERTa model)
- deberta-v2 — DebertaV2ForQuestionAnswering (DeBERTa-v2 model)
- diffllama — DiffLlamaForQuestionAnswering (DiffLlama model)
- distilbert — DistilBertForQuestionAnswering (DistilBERT model)
- electra — ElectraForQuestionAnswering (ELECTRA model)
- ernie — ErnieForQuestionAnswering (ERNIE model)
- ernie_m — ErnieMForQuestionAnswering (ErnieM model)
- falcon — FalconForQuestionAnswering (Falcon model)
- flaubert — FlaubertForQuestionAnsweringSimple (FlauBERT model)
- fnet — FNetForQuestionAnswering (FNet model)
- funnel — FunnelForQuestionAnswering (Funnel Transformer model)
- gpt2 — GPT2ForQuestionAnswering (OpenAI GPT-2 model)
- gpt_neo — GPTNeoForQuestionAnswering (GPT Neo model)
- gpt_neox — GPTNeoXForQuestionAnswering (GPT NeoX model)
- gptj — GPTJForQuestionAnswering (GPT-J model)
- ibert — IBertForQuestionAnswering (I-BERT model)
- layoutlmv2 — LayoutLMv2ForQuestionAnswering (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
- led — LEDForQuestionAnswering (LED model)
- lilt — LiltForQuestionAnswering (LiLT model)
- llama — LlamaForQuestionAnswering (LLaMA model)
- longformer — LongformerForQuestionAnswering (Longformer model)
- luke — LukeForQuestionAnswering (LUKE model)
- lxmert — LxmertForQuestionAnswering (LXMERT model)
- markuplm — MarkupLMForQuestionAnswering (MarkupLM model)
- mbart — MBartForQuestionAnswering (mBART model)
- mega — MegaForQuestionAnswering (MEGA model)
- megatron-bert — MegatronBertForQuestionAnswering (Megatron-BERT model)
- mistral — MistralForQuestionAnswering (Mistral model)
- mixtral — MixtralForQuestionAnswering (Mixtral model)
- mobilebert — MobileBertForQuestionAnswering (MobileBERT model)
- mpnet — MPNetForQuestionAnswering (MPNet model)
- mpt — MptForQuestionAnswering (MPT model)
- mra — MraForQuestionAnswering (MRA model)
- mt5 — MT5ForQuestionAnswering (MT5 model)
- mvp — MvpForQuestionAnswering (MVP model)
- nemotron — NemotronForQuestionAnswering (Nemotron model)
- nezha — NezhaForQuestionAnswering (Nezha model)
- nystromformer — NystromformerForQuestionAnswering (Nyströmformer model)
- opt — OPTForQuestionAnswering (OPT model)
- qdqbert — QDQBertForQuestionAnswering (QDQBert model)
- qwen2 — Qwen2ForQuestionAnswering (Qwen2 model)
- qwen2_moe — Qwen2MoeForQuestionAnswering (Qwen2MoE model)
- reformer — ReformerForQuestionAnswering (Reformer model)
- rembert — RemBertForQuestionAnswering (RemBERT model)
- roberta — RobertaForQuestionAnswering (RoBERTa model)
- roberta-prelayernorm — RobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
- roc_bert — RoCBertForQuestionAnswering (RoCBert model)
- roformer — RoFormerForQuestionAnswering (RoFormer model)
- splinter — SplinterForQuestionAnswering (Splinter model)
- squeezebert — SqueezeBertForQuestionAnswering (SqueezeBERT model)
- t5 — T5ForQuestionAnswering (T5 model)
- umt5 — UMT5ForQuestionAnswering (UMT5 model)
- xlm — XLMForQuestionAnsweringSimple (XLM model)
- xlm-roberta — XLMRobertaForQuestionAnswering (XLM-RoBERTa model)
- xlm-roberta-xl — XLMRobertaXLForQuestionAnswering (XLM-RoBERTa-XL model)
- xlnet — XLNetForQuestionAnsweringSimple (XLNet model)
- xmod — XmodForQuestionAnswering (X-MOD model)
- yoso — YosoForQuestionAnswering (YOSO model)
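To make the model_type rule and its pattern-matching fallback concrete, here is a toy resolver over a handful of keys from the list above. It is an illustration only: it collapses config and model lookup into a single dict, and trying longer keys first is an assumption made here so that substrings such as "bert" inside "albert" or "roberta" inside "xlm-roberta" do not win.

```python
# A handful of model_type keys and class names taken from the mapping above.
QA_CLASS_BY_MODEL_TYPE = {
    "albert": "AlbertForQuestionAnswering",
    "bert": "BertForQuestionAnswering",
    "distilbert": "DistilBertForQuestionAnswering",
    "roberta": "RobertaForQuestionAnswering",
    "xlm": "XLMForQuestionAnsweringSimple",
    "xlm-roberta": "XLMRobertaForQuestionAnswering",
    "t5": "T5ForQuestionAnswering",
    "mt5": "MT5ForQuestionAnswering",
}


def resolve_qa_class(name_or_path, model_type=None):
    """Prefer an explicit model_type; otherwise fall back to substring matching."""
    if model_type is not None:
        return QA_CLASS_BY_MODEL_TYPE[model_type]
    # Longer keys first, so "xlm-roberta" beats "roberta" and "albert" beats "bert".
    for key in sorted(QA_CLASS_BY_MODEL_TYPE, key=len, reverse=True):
        if key in name_or_path:
            return QA_CLASS_BY_MODEL_TYPE[key]
    raise ValueError(f"Could not infer a model type from {name_or_path!r}")


print(resolve_qa_class("FacebookAI/xlm-roberta-base"))  # XLMRobertaForQuestionAnswering
print(resolve_qa_class("albert/albert-base-v2"))  # AlbertForQuestionAnswering
print(resolve_qa_class("./my_model_directory", model_type="bert"))  # BertForQuestionAnswering
```

The fallback is inherently fragile (a local directory name rarely contains a model type), which is why a saved config.json with model_type is the reliable path.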
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForQuestionAnswering.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: TFAlbertForQuestionAnswering (ALBERT model)
  - BertConfig configuration class: TFBertForQuestionAnswering (BERT model)
  - CamembertConfig configuration class: TFCamembertForQuestionAnswering (CamemBERT model)
  - ConvBertConfig configuration class: TFConvBertForQuestionAnswering (ConvBERT model)
  - DebertaConfig configuration class: TFDebertaForQuestionAnswering (DeBERTa model)
  - DebertaV2Config configuration class: TFDebertaV2ForQuestionAnswering (DeBERTa-v2 model)
  - DistilBertConfig configuration class: TFDistilBertForQuestionAnswering (DistilBERT model)
  - ElectraConfig configuration class: TFElectraForQuestionAnswering (ELECTRA model)
  - FlaubertConfig configuration class: TFFlaubertForQuestionAnsweringSimple (FlauBERT model)
  - FunnelConfig configuration class: TFFunnelForQuestionAnswering (Funnel Transformer model)
  - GPTJConfig configuration class: TFGPTJForQuestionAnswering (GPT-J model)
  - LayoutLMv3Config configuration class: TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
  - LongformerConfig configuration class: TFLongformerForQuestionAnswering (Longformer model)
  - MPNetConfig configuration class: TFMPNetForQuestionAnswering (MPNet model)
  - MobileBertConfig configuration class: TFMobileBertForQuestionAnswering (MobileBERT model)
  - RemBertConfig configuration class: TFRemBertForQuestionAnswering (RemBERT model)
  - RoFormerConfig configuration class: TFRoFormerForQuestionAnswering (RoFormer model)
  - RobertaConfig configuration class: TFRobertaForQuestionAnswering (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
  - XLMConfig configuration class: TFXLMForQuestionAnsweringSimple (XLM model)
  - XLMRobertaConfig configuration class: TFXLMRobertaForQuestionAnswering (XLM-RoBERTa model)
  - XLNetConfig configuration class: TFXLNetForQuestionAnsweringSimple (XLNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
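The proxies parameter accepts keys by protocol ("http") or by endpoint ("http://hostname"). The sketch below shows one way such a dict is resolved per request, assuming the dict is handed to a requests-style HTTP client that prefers the most specific key; select_proxy is written here for illustration and is not a Transformers function.

```python
from urllib.parse import urlparse


def select_proxy(url, proxies):
    """Pick a proxy for `url`, most specific key first.

    Precedence assumed here (requests-style): "scheme://host", then "scheme",
    then "all://host", then "all".
    """
    parts = urlparse(url)
    if parts.hostname is None:
        return proxies.get(parts.scheme, proxies.get("all"))
    for key in (
        f"{parts.scheme}://{parts.hostname}",
        parts.scheme,
        f"all://{parts.hostname}",
        "all",
    ):
        if key in proxies:
            return proxies[key]
    return None  # no proxy configured for this URL


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(select_proxy("http://hostname/some/file", proxies))  # foo.bar:4012
print(select_proxy("http://example.com/some/file", proxies))  # foo.bar:3128
print(select_proxy("https://huggingface.co/api", proxies))  # None
```

With the example dict from the parameter description, the endpoint-specific entry wins for "hostname" and the plain "http" entry covers every other HTTP host; HTTPS traffic is not proxied at all unless an "https" key is added.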
Instantiate one of the model classes of the library (with a question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — TFAlbertForQuestionAnswering (ALBERT model)
- bert — TFBertForQuestionAnswering (BERT model)
- camembert — TFCamembertForQuestionAnswering (CamemBERT model)
- convbert — TFConvBertForQuestionAnswering (ConvBERT model)
- deberta — TFDebertaForQuestionAnswering (DeBERTa model)
- deberta-v2 — TFDebertaV2ForQuestionAnswering (DeBERTa-v2 model)
- distilbert — TFDistilBertForQuestionAnswering (DistilBERT model)
- electra — TFElectraForQuestionAnswering (ELECTRA model)
- flaubert — TFFlaubertForQuestionAnsweringSimple (FlauBERT model)
- funnel — TFFunnelForQuestionAnswering (Funnel Transformer model)
- gptj — TFGPTJForQuestionAnswering (GPT-J model)
- layoutlmv3 — TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
- longformer — TFLongformerForQuestionAnswering (Longformer model)
- mobilebert — TFMobileBertForQuestionAnswering (MobileBERT model)
- mpnet — TFMPNetForQuestionAnswering (MPNet model)
- rembert — TFRemBertForQuestionAnswering (RemBERT model)
- roberta — TFRobertaForQuestionAnswering (RoBERTa model)
- roberta-prelayernorm — TFRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
- roformer — TFRoFormerForQuestionAnswering (RoFormer model)
- xlm — TFXLMForQuestionAnsweringSimple (XLM model)
- xlm-roberta — TFXLMRobertaForQuestionAnswering (XLM-RoBERTa model)
- xlnet — TFXLNetForQuestionAnsweringSimple (XLNet model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForQuestionAnswering.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlbertConfig configuration class: FlaxAlbertForQuestionAnswering (ALBERT model)
  - BartConfig configuration class: FlaxBartForQuestionAnswering (BART model)
  - BertConfig configuration class: FlaxBertForQuestionAnswering (BERT model)
  - BigBirdConfig configuration class: FlaxBigBirdForQuestionAnswering (BigBird model)
  - DistilBertConfig configuration class: FlaxDistilBertForQuestionAnswering (DistilBERT model)
  - ElectraConfig configuration class: FlaxElectraForQuestionAnswering (ELECTRA model)
  - MBartConfig configuration class: FlaxMBartForQuestionAnswering (mBART model)
  - RoFormerConfig configuration class: FlaxRoFormerForQuestionAnswering (RoFormer model)
  - RobertaConfig configuration class: FlaxRobertaForQuestionAnswering (RoBERTa model)
  - RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
  - XLMRobertaConfig configuration class: FlaxXLMRobertaForQuestionAnswering (XLM-RoBERTa model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
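With output_loading_info=True, the returned dictionary reports how the checkpoint's parameter names lined up with the model's. A minimal sketch of that comparison as plain set arithmetic follows; loading_info and the parameter names are hypothetical, and the toy leaves error_msgs empty since it never actually loads tensors.

```python
def loading_info(expected_keys, checkpoint_keys):
    """Compare the parameter names a model expects with those in a checkpoint."""
    expected, found = set(expected_keys), set(checkpoint_keys)
    return {
        # Needed by the model but absent from the checkpoint (left freshly initialized).
        "missing_keys": sorted(expected - found),
        # Present in the checkpoint but unused by the model (ignored).
        "unexpected_keys": sorted(found - expected),
        "error_msgs": [],
    }


# Hypothetical names: a base checkpoint loaded into a model with a new QA head.
expected = ["encoder.embeddings", "qa_outputs.kernel", "qa_outputs.bias"]
checkpoint = ["encoder.embeddings", "lm_head.bias"]

info = loading_info(expected, checkpoint)
print(info["missing_keys"])  # ['qa_outputs.bias', 'qa_outputs.kernel']
print(info["unexpected_keys"])  # ['lm_head.bias']
```

Missing QA-head parameters are the expected outcome when loading a base checkpoint into a question answering model; they indicate the head still needs fine-tuning.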
Instantiate one of the model classes of the library (with a question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertForQuestionAnswering (ALBERT model)
- bart — FlaxBartForQuestionAnswering (BART model)
- bert — FlaxBertForQuestionAnswering (BERT model)
- big_bird — FlaxBigBirdForQuestionAnswering (BigBird model)
- distilbert — FlaxDistilBertForQuestionAnswering (DistilBERT model)
- electra — FlaxElectraForQuestionAnswering (ELECTRA model)
- mbart — FlaxMBartForQuestionAnswering (mBART model)
- roberta — FlaxRobertaForQuestionAnswering (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerForQuestionAnswering (RoFormer model)
- xlm-roberta — FlaxXLMRobertaForQuestionAnswering (XLM-RoBERTa model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForQuestionAnswering.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForTextEncoding
TFAutoModelForTextEncoding
Computer vision
The following auto classes are available for the computer vision tasks below.
AutoModelForDepthEstimation
This is a generic model class that will be instantiated as one of the model classes of the library (with a depth estimation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
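Refusing to run __init__() is a factory pattern. The generic class exists only to dispatch, so callers must go through a class method that returns an instance of a concrete class. The plain-Python sketch below shows that pattern. GenericAutoModel, TinyDepthModel and the "tiny-depth" key are hypothetical names, a dict stands in for the configuration object, and this is not the library's actual implementation.

```python
class TinyDepthModel:
    """Hypothetical stand-in for a concrete model class."""

    def __init__(self, config):
        self.config = config


class GenericAutoModel:
    """Hypothetical auto class: it only builds *other* classes."""

    # Maps a config "model_type" string to the concrete class to build.
    _model_mapping = {"tiny-depth": TinyDepthModel}

    def __init__(self, *args, **kwargs):
        # Direct instantiation is refused, mirroring "(throws an error)".
        raise EnvironmentError(
            f"{type(self).__name__} is designed to be instantiated "
            "using a class method such as from_config()."
        )

    @classmethod
    def from_config(cls, config):
        # Look up the concrete class and return an instance of it,
        # never an instance of the generic auto class itself.
        model_class = cls._model_mapping[config["model_type"]]
        return model_class(config)


model = GenericAutoModel.from_config({"model_type": "tiny-depth"})
print(type(model).__name__)
```

The returned object is a TinyDepthModel, which is why the auto class itself never needs a working constructor.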
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - DPTConfig configuration class: DPTForDepthEstimation (DPT model)
  - DepthAnythingConfig configuration class: DepthAnythingForDepthEstimation (Depth Anything model)
  - GLPNConfig configuration class: GLPNForDepthEstimation (GLPN model)
  - ZoeDepthConfig configuration class: ZoeDepthForDepthEstimation (ZoeDepth model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a depth estimation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
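The default rule for attn_implementation described above can be summarized as a small decision function. This is a hypothetical sketch: resolve_attn_implementation and version_tuple are illustrative names. Unlike the library, it does not verify that the flash-attention package is installed. It also assumes a plain numeric torch version string such as "2.1.1" or "2.1.1+cu121".

```python
from __future__ import annotations


def version_tuple(version: str) -> tuple[int, ...]:
    """Turn '2.1.1' or '2.1.1+cu121' into (2, 1, 1) for comparison."""
    core = version.split("+")[0]  # drop a local build tag such as +cu121
    return tuple(int(part) for part in core.split(".")[:3])


def resolve_attn_implementation(requested, torch_version, sdpa_supported):
    """Hypothetical sketch of the default described in the docstring."""
    allowed = {"eager", "sdpa", "flash_attention_2"}
    if requested is not None:
        if requested not in allowed:
            raise ValueError(f"Unknown attn_implementation: {requested!r}")
        return requested  # an explicit choice always wins in this sketch
    # No explicit choice: use SDPA when the model supports it and torch is recent enough.
    if sdpa_supported and version_tuple(torch_version) >= (2, 1, 1):
        return "sdpa"
    return "eager"  # otherwise fall back to the manual implementation


print(resolve_attn_implementation(None, "2.1.1", True))
```

Comparing integer tuples instead of raw strings avoids ordering bugs such as "2.10.0" sorting before "2.2.0".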
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() would be a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, it can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
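The two kwargs cases above amount to a routing rule. Below is a minimal sketch of that rule. It assumes a plain dict stands in for the configuration object, whose keys play the role of configuration attributes. split_kwargs is a hypothetical helper, not a Transformers function.

```python
def split_kwargs(config: dict, kwargs: dict, config_provided: bool):
    """Hypothetical sketch of how from_pretrained() routes **kwargs.

    Returns a pair (config, model_init_kwargs)."""
    if config_provided:
        # The config is taken as final: every kwarg goes to the model's __init__.
        return config, dict(kwargs)
    updated = dict(config)  # copy so the loaded config is not mutated
    model_kwargs = {}
    for key, value in kwargs.items():
        if key in updated:
            updated[key] = value  # overrides a configuration attribute
        else:
            model_kwargs[key] = value  # unknown to the config: passed to __init__
    return updated, model_kwargs


loaded = {"output_attentions": False, "num_labels": 2}
cfg, rest = split_kwargs(loaded, {"output_attentions": True, "my_flag": 1}, config_provided=False)
print(cfg, rest)
```

This is why output_attentions=True in the examples shows up as model.config.output_attentions. It matches a configuration attribute, so it is folded into the config instead of reaching the model constructor.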
Instantiate one of the model classes of the library (with a depth estimation head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- depth_anything — DepthAnythingForDepthEstimation (Depth Anything model)
- dpt — DPTForDepthEstimation (DPT model)
- glpn — GLPNForDepthEstimation (GLPN model)
- zoedepth — ZoeDepthForDepthEstimation (ZoeDepth model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
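Calling model.eval() switches every submodule into inference behavior, which is why dropout stops dropping. The toy classes below mimic, in simplified form, how torch.nn.Module.train(mode) and eval() propagate a training flag to child modules. ToyModule and ToyDropout are hypothetical stand-ins, not PyTorch classes.

```python
import random


class ToyModule:
    """Hypothetical stand-in for a module with a `training` flag."""

    def __init__(self, *children):
        self.training = True
        self.children = list(children)

    def train(self, mode: bool = True):
        # Set the flag on this module and recurse into every submodule.
        self.training = mode
        for child in self.children:
            child.train(mode)
        return self

    def eval(self):
        # eval() is just train(False), as in PyTorch.
        return self.train(False)


class ToyDropout(ToyModule):
    def __init__(self, p: float):
        super().__init__()
        self.p = p

    def __call__(self, values):
        if not self.training:
            return list(values)  # identity at inference time
        keep = 1.0 - self.p
        # Inverted dropout: zero some entries and rescale the survivors.
        return [v / keep if random.random() < keep else 0.0 for v in values]


dropout = ToyDropout(p=0.5)
model = ToyModule(ToyModule(dropout)).eval()
print(dropout([1.0, 2.0, 3.0]))
```

Because the flag propagates recursively, a dropout layer nested several levels deep is deactivated by a single eval() call on the outermost module.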
Examples:
>>> from transformers import AutoConfig, AutoModelForDepthEstimation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForDepthEstimation.from_pretrained("Intel/dpt-large")
>>> # Update configuration during loading
>>> model = AutoModelForDepthEstimation.from_pretrained("Intel/dpt-large", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForDepthEstimation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForImageClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an image classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BeitConfig configuration class: BeitForImageClassification (BEiT model)
  - BitConfig configuration class: BitForImageClassification (BiT model)
  - CLIPConfig configuration class: CLIPForImageClassification (CLIP model)
  - ConvNextConfig configuration class: ConvNextForImageClassification (ConvNeXT model)
  - ConvNextV2Config configuration class: ConvNextV2ForImageClassification (ConvNeXTV2 model)
  - CvtConfig configuration class: CvtForImageClassification (CvT model)
  - Data2VecVisionConfig configuration class: Data2VecVisionForImageClassification (Data2VecVision model)
  - DeiTConfig configuration class: DeiTForImageClassification or DeiTForImageClassificationWithTeacher (DeiT model)
  - DinatConfig configuration class: DinatForImageClassification (DiNAT model)
  - Dinov2Config configuration class: Dinov2ForImageClassification (DINOv2 model)
  - Dinov2WithRegistersConfig configuration class: Dinov2WithRegistersForImageClassification (DINOv2 with Registers model)
  - EfficientFormerConfig configuration class: EfficientFormerForImageClassification or EfficientFormerForImageClassificationWithTeacher (EfficientFormer model)
  - EfficientNetConfig configuration class: EfficientNetForImageClassification (EfficientNet model)
  - FocalNetConfig configuration class: FocalNetForImageClassification (FocalNet model)
  - HieraConfig configuration class: HieraForImageClassification (Hiera model)
  - IJepaConfig configuration class: IJepaForImageClassification (I-JEPA model)
  - ImageGPTConfig configuration class: ImageGPTForImageClassification (ImageGPT model)
  - LevitConfig configuration class: LevitForImageClassification or LevitForImageClassificationWithTeacher (LeViT model)
  - MobileNetV1Config configuration class: MobileNetV1ForImageClassification (MobileNetV1 model)
  - MobileNetV2Config configuration class: MobileNetV2ForImageClassification (MobileNetV2 model)
  - MobileViTConfig configuration class: MobileViTForImageClassification (MobileViT model)
  - MobileViTV2Config configuration class: MobileViTV2ForImageClassification (MobileViTV2 model)
  - NatConfig configuration class: NatForImageClassification (NAT model)
  - PerceiverConfig configuration class: PerceiverForImageClassificationLearned or PerceiverForImageClassificationFourier or PerceiverForImageClassificationConvProcessing (Perceiver model)
  - PoolFormerConfig configuration class: PoolFormerForImageClassification (PoolFormer model)
  - PvtConfig configuration class: PvtForImageClassification (PVT model)
  - PvtV2Config configuration class: PvtV2ForImageClassification (PVTv2 model)
  - RegNetConfig configuration class: RegNetForImageClassification (RegNet model)
  - ResNetConfig configuration class: ResNetForImageClassification (ResNet model)
  - SegformerConfig configuration class: SegformerForImageClassification (SegFormer model)
  - SiglipConfig configuration class: SiglipForImageClassification (SigLIP model)
  - SwiftFormerConfig configuration class: SwiftFormerForImageClassification (SwiftFormer model)
  - SwinConfig configuration class: SwinForImageClassification (Swin Transformer model)
  - Swinv2Config configuration class: Swinv2ForImageClassification (Swin Transformer V2 model)
  - TextNetConfig configuration class: TextNetForImageClassification (TextNet model)
  - TimmWrapperConfig configuration class: TimmWrapperForImageClassification (TimmWrapperModel model)
  - VanConfig configuration class: VanForImageClassification (VAN model)
  - ViTConfig configuration class: ViTForImageClassification (ViT model)
  - ViTHybridConfig configuration class: ViTHybridForImageClassification (ViT Hybrid model)
  - ViTMSNConfig configuration class: ViTMSNForImageClassification (ViTMSN model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an image classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
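from_config() picks the model class from the type of the configuration object (the list above is keyed by configuration class). It then builds a freshly initialized model without reading any checkpoint. The sketch below imitates both points with hypothetical toy classes (ToyViTConfig, ToyViTForImageClassification) and a dict-based registry. It only illustrates the behavior described above and is not the library's code.

```python
import random
from dataclasses import dataclass


@dataclass
class ToyViTConfig:
    """Hypothetical config class (a stand-in for something like ViTConfig)."""
    hidden_size: int = 4
    num_labels: int = 3


class ToyViTForImageClassification:
    def __init__(self, config):
        self.config = config
        # Fresh random weights: nothing is loaded from a checkpoint.
        self.classifier = [
            [random.gauss(0.0, 0.02) for _ in range(config.hidden_size)]
            for _ in range(config.num_labels)
        ]


# Registry keyed by configuration *class*, like the list above.
CONFIG_TO_MODEL = {ToyViTConfig: ToyViTForImageClassification}


def from_config(config):
    """Hypothetical dispatch on type(config); no weights are loaded."""
    model_class = CONFIG_TO_MODEL.get(type(config))
    if model_class is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    return model_class(config)


model = from_config(ToyViTConfig(hidden_size=8, num_labels=5))
print(type(model).__name__, len(model.classifier))
```

Two models built from the same configuration get different random weights. Only from_pretrained() restores trained parameters.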
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() would be a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, it can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
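With output_loading_info=True, from_pretrained() also reports missing and unexpected keys. Conceptually these are set differences between the parameter names the model expects and the names found in the checkpoint. The sketch below computes them for a hypothetical base checkpoint loaded into a classification model. The parameter names and the result dictionary's key names are assumptions for illustration, and loading_info is a hypothetical helper.

```python
def loading_info(expected_keys, checkpoint_keys):
    """Hypothetical sketch: compare a model's parameter names with a checkpoint's."""
    expected, found = set(expected_keys), set(checkpoint_keys)
    return {
        # In the model but absent from the checkpoint: these stay newly initialized.
        "missing_keys": sorted(expected - found),
        # In the checkpoint but not used by the model: these are ignored.
        "unexpected_keys": sorted(found - expected),
        # A real loader would also collect shape or format errors here.
        "error_msgs": [],
    }


info = loading_info(
    ["vit.embeddings.weight", "classifier.weight", "classifier.bias"],
    ["vit.embeddings.weight", "pooler.dense.weight"],
)
print(info)
```

A freshly added classification head typically shows up under missing keys, which is why such a model needs fine-tuning before its predictions mean anything.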
Instantiate one of the model classes of the library (with an image classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- beit — BeitForImageClassification (BEiT model)
- bit — BitForImageClassification (BiT model)
- clip — CLIPForImageClassification (CLIP model)
- convnext — ConvNextForImageClassification (ConvNeXT model)
- convnextv2 — ConvNextV2ForImageClassification (ConvNeXTV2 model)
- cvt — CvtForImageClassification (CvT model)
- data2vec-vision — Data2VecVisionForImageClassification (Data2VecVision model)
- deit — DeiTForImageClassification or DeiTForImageClassificationWithTeacher (DeiT model)
- dinat — DinatForImageClassification (DiNAT model)
- dinov2 — Dinov2ForImageClassification (DINOv2 model)
- dinov2_with_registers — Dinov2WithRegistersForImageClassification (DINOv2 with Registers model)
- efficientformer — EfficientFormerForImageClassification or EfficientFormerForImageClassificationWithTeacher (EfficientFormer model)
- efficientnet — EfficientNetForImageClassification (EfficientNet model)
- focalnet — FocalNetForImageClassification (FocalNet model)
- hiera — HieraForImageClassification (Hiera model)
- ijepa — IJepaForImageClassification (I-JEPA model)
- imagegpt — ImageGPTForImageClassification (ImageGPT model)
- levit — LevitForImageClassification or LevitForImageClassificationWithTeacher (LeViT model)
- mobilenet_v1 — MobileNetV1ForImageClassification (MobileNetV1 model)
- mobilenet_v2 — MobileNetV2ForImageClassification (MobileNetV2 model)
- mobilevit — MobileViTForImageClassification (MobileViT model)
- mobilevitv2 — MobileViTV2ForImageClassification (MobileViTV2 model)
- nat — NatForImageClassification (NAT model)
- perceiver — PerceiverForImageClassificationLearned or PerceiverForImageClassificationFourier or PerceiverForImageClassificationConvProcessing (Perceiver model)
- poolformer — PoolFormerForImageClassification (PoolFormer model)
- pvt — PvtForImageClassification (PVT model)
- pvt_v2 — PvtV2ForImageClassification (PVTv2 model)
- regnet — RegNetForImageClassification (RegNet model)
- resnet — ResNetForImageClassification (ResNet model)
- segformer — SegformerForImageClassification (SegFormer model)
- siglip — SiglipForImageClassification (SigLIP model)
- swiftformer — SwiftFormerForImageClassification (SwiftFormer model)
- swin — SwinForImageClassification (Swin Transformer model)
- swinv2 — Swinv2ForImageClassification (Swin Transformer V2 model)
- textnet — TextNetForImageClassification (TextNet model)
- timm_wrapper — TimmWrapperForImageClassification (TimmWrapperModel model)
- van — VanForImageClassification (VAN model)
- vit — ViTForImageClassification (ViT model)
- vit_hybrid — ViTHybridForImageClassification (ViT Hybrid model)
- vit_msn — ViTMSNForImageClassification (ViTMSN model)
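The selection described above uses the config's model_type when it exists. Otherwise it looks for a known key inside the model name or path. Keys overlap (mobilevit is a prefix of mobilevitv2, and vit appears inside both), so a name-based fallback should try longer keys first. Below is a minimal sketch that works on a small excerpt of the list. It assumes simple substring matching ordered by key length, and that ordering is an assumption about the library's internals rather than a documented guarantee. select_model_class is a hypothetical helper.

```python
# A small excerpt of the model_type -> class list above.
IMAGE_CLASSIFICATION_MAPPING = {
    "mobilevit": "MobileViTForImageClassification",
    "mobilevitv2": "MobileViTV2ForImageClassification",
    "vit": "ViTForImageClassification",
}


def select_model_class(model_type, name_or_path):
    """Sketch: prefer config.model_type, else fall back to name matching."""
    if model_type is not None:
        return IMAGE_CLASSIFICATION_MAPPING[model_type]
    # Check longer keys first so "mobilevitv2" wins over "mobilevit",
    # and both win over the short "vit".
    for key in sorted(IMAGE_CLASSIFICATION_MAPPING, key=len, reverse=True):
        if key in name_or_path:
            return IMAGE_CLASSIFICATION_MAPPING[key]
    raise ValueError(f"Could not infer the model type from {name_or_path!r}")


print(select_model_class(None, "apple/mobilevitv2-1.0"))
```

Name matching is only a heuristic. A checkpoint whose config.json carries model_type never needs it.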
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> # Update configuration during loading
>>> model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForImageClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForImageClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an image classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - ConvNextConfig configuration class: TFConvNextForImageClassification (ConvNeXT model)
  - ConvNextV2Config configuration class: TFConvNextV2ForImageClassification (ConvNeXTV2 model)
  - CvtConfig configuration class: TFCvtForImageClassification (CvT model)
  - Data2VecVisionConfig configuration class: TFData2VecVisionForImageClassification (Data2VecVision model)
  - DeiTConfig configuration class: TFDeiTForImageClassification or TFDeiTForImageClassificationWithTeacher (DeiT model)
  - EfficientFormerConfig configuration class: TFEfficientFormerForImageClassification or TFEfficientFormerForImageClassificationWithTeacher (EfficientFormer model)
  - MobileViTConfig configuration class: TFMobileViTForImageClassification (MobileViT model)
  - RegNetConfig configuration class: TFRegNetForImageClassification (RegNet model)
  - ResNetConfig configuration class: TFResNetForImageClassification (ResNet model)
  - SegformerConfig configuration class: TFSegformerForImageClassification (SegFormer model)
  - SwiftFormerConfig configuration class: TFSwiftFormerForImageClassification (SwiftFormer model)
  - SwinConfig configuration class: TFSwinForImageClassification (Swin Transformer model)
  - ViTConfig configuration class: TFViTForImageClassification (ViT model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an image classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, it can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it is loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
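The proxies example above mixes a per-protocol key ('http') with a per-endpoint key ('http://hostname'). The sketch below shows how such a dictionary could be resolved for a single request. It assumes requests-style precedence: "scheme://host" first, then "scheme", then "all://host", then "all". That precedence is an assumption about the underlying HTTP client, not something this page states, and pick_proxy is a hypothetical helper.

```python
from urllib.parse import urlparse


def pick_proxy(url, proxies):
    """Hypothetical sketch of per-request proxy selection (requests-style order)."""
    parts = urlparse(url)
    candidates = [
        f"{parts.scheme}://{parts.hostname}",  # most specific: this endpoint
        parts.scheme,                          # any host using this protocol
        f"all://{parts.hostname}",             # this host, any protocol
        "all",                                 # catch-all
    ]
    for key in candidates:
        if key in proxies:
            return proxies[key]
    return None  # no proxy: connect directly


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(pick_proxy("http://hostname/some/file", proxies))
```

With the example dictionary, requests to http://hostname go through foo.bar:4012, every other plain-HTTP request uses foo.bar:3128, and HTTPS requests are not proxied at all.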
Instantiate one of the model classes of the library (with an image classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- convnext — TFConvNextForImageClassification (ConvNeXT model)
- convnextv2 — TFConvNextV2ForImageClassification (ConvNeXTV2 model)
- cvt — TFCvtForImageClassification (CvT model)
- data2vec-vision — TFData2VecVisionForImageClassification (Data2VecVision model)
- deit — TFDeiTForImageClassification or TFDeiTForImageClassificationWithTeacher (DeiT model)
- efficientformer — TFEfficientFormerForImageClassification or TFEfficientFormerForImageClassificationWithTeacher (EfficientFormer model)
- mobilevit — TFMobileViTForImageClassification (MobileViT model)
- regnet — TFRegNetForImageClassification (RegNet model)
- resnet — TFResNetForImageClassification (ResNet model)
- segformer — TFSegformerForImageClassification (SegFormer model)
- swiftformer — TFSwiftFormerForImageClassification (SwiftFormer model)
- swin — TFSwinForImageClassification (Swin Transformer model)
- vit — TFViTForImageClassification (ViT model)
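The TensorFlow list above is much shorter than the PyTorch one, and the Flax list for this task is shorter still. An architecture reachable through AutoModelForImageClassification may therefore have no TensorFlow or Flax counterpart. The sketch below encodes a subset of the model_type keys copied from the three image-classification lists in this section and reports which backends cover a given key. The sets are a hand-copied excerpt, not a query of the library, and backends_for is a hypothetical helper.

```python
# Subsets of the model_type keys listed in this section for each backend.
PT_KEYS = {"beit", "convnext", "dinov2", "regnet", "resnet", "swin", "vit"}
TF_KEYS = {"convnext", "regnet", "resnet", "swin", "vit"}
FLAX_KEYS = {"beit", "dinov2", "regnet", "resnet", "vit"}

# Insertion order fixes the order of the returned backend names.
BACKENDS = {"PyTorch": PT_KEYS, "TensorFlow": TF_KEYS, "Flax": FLAX_KEYS}


def backends_for(model_type):
    """Return the backends whose image-classification auto class covers model_type."""
    return [name for name, keys in BACKENDS.items() if model_type in keys]


print(backends_for("beit"))
```

Checking coverage this way before choosing a backend avoids a lookup error from the auto class when a checkpoint's model_type has no mapping there.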
Examples:
>>> from transformers import AutoConfig, TFAutoModelForImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> # Update configuration during loading
>>> model = TFAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForImageClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForImageClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an image classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BeitConfig configuration class: FlaxBeitForImageClassification (BEiT model)
  - Dinov2Config configuration class: FlaxDinov2ForImageClassification (DINOv2 model)
  - RegNetConfig configuration class: FlaxRegNetForImageClassification (RegNet model)
  - ResNetConfig configuration class: FlaxResNetForImageClassification (ResNet model)
  - ViTConfig configuration class: FlaxViTForImageClassification (ViT model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an image classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the PyTorch model into a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
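The two kwargs behaviours described above can be summarised in a few lines of plain Python. This is a simplified sketch, not the library's code. The split_kwargs helper is hypothetical, and a plain dict stands in for the configuration object and its attributes.

```python
from typing import Any, Dict, Tuple


def split_kwargs(
    config_attrs: Dict[str, Any],
    kwargs: Dict[str, Any],
    config_provided: bool,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (updated config attributes, kwargs forwarded to the model's __init__).

    Hypothetical helper following the documented rule:
    - config passed explicitly: every kwarg goes to the model's __init__ untouched;
    - config loaded automatically: kwargs naming a config attribute override it,
      and the remaining kwargs go to the model's __init__.
    """
    if config_provided:
        return dict(config_attrs), dict(kwargs)

    updated = dict(config_attrs)  # copy so the caller's dict is not mutated
    model_kwargs: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in updated:
            updated[key] = value  # override the matching configuration attribute
        else:
            model_kwargs[key] = value  # unknown to the config: forward to __init__
    return updated, model_kwargs


config = {"output_attentions": False, "num_labels": 2}
cfg, extra = split_kwargs(config, {"output_attentions": True, "foo": 1}, config_provided=False)
print(cfg)    # {'output_attentions': True, 'num_labels': 2}
print(extra)  # {'foo': 1}
```

The split depends only on which names the configuration already has. In this sketch, a misspelled attribute name is therefore forwarded to __init__ and does not update the configuration.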
Instantiate one of the model classes of the library (with an image classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- beit — FlaxBeitForImageClassification (BEiT model)
- dinov2 — FlaxDinov2ForImageClassification (DINOv2 model)
- regnet — FlaxRegNetForImageClassification (RegNet model)
- resnet — FlaxResNetForImageClassification (ResNet model)
- vit — FlaxViTForImageClassification (ViT model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForImageClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForVideoClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a video classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - TimesformerConfig configuration class: TimesformerForVideoClassification (TimeSformer model)
  - VideoMAEConfig configuration class: VideoMAEForVideoClassification (VideoMAE model)
  - VivitConfig configuration class: VivitForVideoClassification (ViViT model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a video classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
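The default rule for attn_implementation described above (SDPA when available on torch>=2.1.1, otherwise "eager") can be written as a small pure-Python function. The pick_attn_implementation name and its sdpa_supported flag are hypothetical. The real library also checks per-model support and installed packages, which this sketch does not model.

```python
from typing import Optional

VALID_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2"}


def parse_version(version: str) -> tuple:
    """Turn '2.1.1+cu118' into (2, 1, 1); assumes a plain numeric X.Y.Z prefix."""
    release = version.split("+")[0]
    return tuple(int(part) for part in release.split(".")[:3])


def pick_attn_implementation(
    requested: Optional[str], torch_version: str, sdpa_supported: bool
) -> str:
    """Hypothetical sketch of the documented default for attn_implementation."""
    if requested is not None:
        if requested not in VALID_IMPLEMENTATIONS:
            raise ValueError(f"Unknown attn_implementation: {requested!r}")
        return requested  # an explicit choice always wins
    if sdpa_supported and parse_version(torch_version) >= (2, 1, 1):
        return "sdpa"
    return "eager"  # manual fallback implementation


print(pick_attn_implementation(None, "2.1.1", sdpa_supported=True))  # sdpa
print(pick_attn_implementation(None, "2.0.1", sdpa_supported=True))  # eager
```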
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check if using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
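A proxies dictionary can be keyed either by protocol ('http') or by a specific endpoint ('http://hostname'). The sketch below shows one way to resolve such a dictionary for a single URL, with the endpoint-specific key taking precedence over the protocol-wide key. It follows the lookup order used by the requests library's proxy selection. Whether a given Transformers or huggingface_hub version resolves proxies exactly this way is an assumption, and select_proxy_for is a hypothetical name.

```python
from typing import Dict, Optional
from urllib.parse import urlparse


def select_proxy_for(url: str, proxies: Dict[str, str]) -> Optional[str]:
    """Return the proxy to use for `url`, preferring the most specific key (hypothetical helper)."""
    parts = urlparse(url)
    if parts.hostname is None:
        return proxies.get(parts.scheme, proxies.get("all"))
    candidate_keys = [
        f"{parts.scheme}://{parts.hostname}",  # endpoint-specific, e.g. 'http://hostname'
        parts.scheme,                          # protocol-wide, e.g. 'http'
        f"all://{parts.hostname}",             # any protocol, this host
        "all",                                 # catch-all
    ]
    for key in candidate_keys:
        if key in proxies:
            return proxies[key]
    return None  # no proxy configured for this URL


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(select_proxy_for("http://hostname/model.bin", proxies))    # foo.bar:4012
print(select_proxy_for("http://other.host/model.bin", proxies))  # foo.bar:3128
print(select_proxy_for("https://hostname/model.bin", proxies))   # None
```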
Instantiate one of the model classes of the library (with a video classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- timesformer — TimesformerForVideoClassification (TimeSformer model)
- videomae — VideoMAEForVideoClassification (VideoMAE model)
- vivit — VivitForVideoClassification (ViViT model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForVideoClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
>>> # Update configuration during loading
>>> model = AutoModelForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForVideoClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForKeypointDetection
AutoModelForMaskedImageModeling
This is a generic model class that will be instantiated as one of the model classes of the library (with a masked image modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - DeiTConfig configuration class: DeiTForMaskedImageModeling (DeiT model)
  - FocalNetConfig configuration class: FocalNetForMaskedImageModeling (FocalNet model)
  - SwinConfig configuration class: SwinForMaskedImageModeling (Swin Transformer model)
  - Swinv2Config configuration class: Swinv2ForMaskedImageModeling (Swin Transformer V2 model)
  - ViTConfig configuration class: ViTForMaskedImageModeling (ViT model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a masked image modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check if using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
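With output_loading_info=True, the returned dictionary reports missing keys and unexpected keys. Conceptually, these are set differences between the parameter names the instantiated model expects and the names found in the checkpoint. The sketch below shows only that idea. It uses hypothetical key names and dictionary keys, and it ignores details the real loader handles, such as base-model prefixes, tied weights, and shape mismatches.

```python
from typing import Dict, Iterable, List


def loading_info(expected: Iterable[str], loaded: Iterable[str]) -> Dict[str, List[str]]:
    """Hypothetical sketch of a missing/unexpected key report (sorted for stable output)."""
    expected_set, loaded_set = set(expected), set(loaded)
    return {
        "missing_keys": sorted(expected_set - loaded_set),     # in the model, absent from the checkpoint
        "unexpected_keys": sorted(loaded_set - expected_set),  # in the checkpoint, unused by the model
        "error_msgs": [],  # this sketch never records errors
    }


# Illustrative parameter names only: a masked-image-modeling model loaded from
# a checkpoint that was saved with a classification head instead of a decoder.
model_keys = ["encoder.layer.0.weight", "decoder.weight"]
checkpoint_keys = ["encoder.layer.0.weight", "classifier.weight"]
info = loading_info(model_keys, checkpoint_keys)
print(info["missing_keys"])     # ['decoder.weight']
print(info["unexpected_keys"])  # ['classifier.weight']
```

In this scenario, the missing decoder weights would need to be freshly initialized before the model can be used for masked image modeling.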
Instantiate one of the model classes of the library (with a masked image modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- deit — DeiTForMaskedImageModeling (DeiT model)
- focalnet — FocalNetForMaskedImageModeling (FocalNet model)
- swin — SwinForMaskedImageModeling (Swin Transformer model)
- swinv2 — Swinv2ForMaskedImageModeling (Swin Transformer V2 model)
- vit — ViTForMaskedImageModeling (ViT model)
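When the configuration has no model_type, the fallback matches known model-type names against pretrained_model_name_or_path. Because "swin" is a substring of "swinv2", the order in which names are tried matters. The sketch below tries longer names first, so that a Swin V2 checkpoint name is not captured by "swin". The resolve_model_class helper and the example repository names are hypothetical, and the library's exact matching rules may differ.

```python
from typing import Optional

# model_type -> model class, as listed above for AutoModelForMaskedImageModeling.
MODEL_TYPE_TO_CLASS = {
    "deit": "DeiTForMaskedImageModeling",
    "focalnet": "FocalNetForMaskedImageModeling",
    "swin": "SwinForMaskedImageModeling",
    "swinv2": "Swinv2ForMaskedImageModeling",
    "vit": "ViTForMaskedImageModeling",
}


def resolve_model_class(model_type: Optional[str], name_or_path: str) -> str:
    """Prefer the config's model_type; otherwise fall back to substring matching."""
    if model_type is not None:
        return MODEL_TYPE_TO_CLASS[model_type]
    # Try longer names first so "swinv2" is checked before its substring "swin".
    for pattern in sorted(MODEL_TYPE_TO_CLASS, key=len, reverse=True):
        if pattern in name_or_path:
            return MODEL_TYPE_TO_CLASS[pattern]
    raise ValueError(f"Could not infer a model type from {name_or_path!r}")


print(resolve_model_class(None, "my-org/swinv2-tiny"))  # Swinv2ForMaskedImageModeling
print(resolve_model_class(None, "my-org/swin-tiny"))    # SwinForMaskedImageModeling
```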
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForMaskedImageModeling
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> # Update configuration during loading
>>> model = AutoModelForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForMaskedImageModeling.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForMaskedImageModeling
This is a generic model class that will be instantiated as one of the model classes of the library (with a masked image modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - DeiTConfig configuration class: TFDeiTForMaskedImageModeling (DeiT model)
  - SwinConfig configuration class: TFSwinForMaskedImageModeling (Swin Transformer model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a masked image modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since we use a git-based system for storing models and other artifacts on huggingface.co, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
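The accepted forms of pretrained_model_name_or_path can be told apart by checking the local filesystem first. The sketch below classifies an argument as a local directory, a local file, or otherwise a Hub model id. The classify_source helper is hypothetical. URL handling and the actual download and loading logic are deliberately left out.

```python
import os
import tempfile


def classify_source(name_or_path: str) -> str:
    """Hypothetical helper: decide which documented form `name_or_path` takes."""
    if os.path.isdir(name_or_path):
        return "directory"  # e.g. ./my_model_directory/ written by save_pretrained()
    if os.path.isfile(name_or_path):
        return "file"  # e.g. ./pt_model/pytorch_model.bin, loaded with from_pt=True
    return "hub_model_id"  # otherwise treated as a model id on huggingface.co


# Demonstrate with a throwaway directory containing an empty placeholder file.
with tempfile.TemporaryDirectory() as tmp:
    weights = os.path.join(tmp, "pytorch_model.bin")
    open(weights, "wb").close()
    kinds = (classify_source(tmp), classify_source(weights))
print(kinds)  # ('directory', 'file')
```

Checking the filesystem before treating the string as a Hub id means that a local folder with the same name as a Hub repository takes precedence in this sketch.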
Instantiate one of the model classes of the library (with a masked image modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- deit — TFDeiTForMaskedImageModeling (DeiT model)
- swin — TFSwinForMaskedImageModeling (Swin Transformer model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForMaskedImageModeling
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> # Update configuration during loading
>>> model = TFAutoModelForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForMaskedImageModeling.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForObjectDetection
This is a generic model class that will be instantiated as one of the model classes of the library (with an object detection head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - ConditionalDetrConfig configuration class: ConditionalDetrForObjectDetection (Conditional DETR model)
  - DeformableDetrConfig configuration class: DeformableDetrForObjectDetection (Deformable DETR model)
  - DetaConfig configuration class: DetaForObjectDetection (DETA model)
  - DetrConfig configuration class: DetrForObjectDetection (DETR model)
  - RTDetrConfig configuration class: RTDetrForObjectDetection (RT-DETR model)
  - TableTransformerConfig configuration class: TableTransformerForObjectDetection (Table Transformer model)
  - YolosConfig configuration class: YolosForObjectDetection (YOLOS model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an object detection head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (
str
oros.PathLike
) — Can be either:- A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
- A path to a directory containing model weights saved using
save_pretrained(), e.g.,
./my_model_directory/
. - A path or url to a tensorflow index checkpoint file (e.g,
./tf_model/model.ckpt.index
). In this case,from_tf
should be set toTrue
and a configuration object should be provided asconfig
argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) —
Will be passed along to the underlying model
__init__()
method. - config (PretrainedConfig, optional) —
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as
pretrained_model_name_or_path
and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) —
A state dictionary to use instead of a state dictionary loaded from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check if using save_pretrained() and from_pretrained() is not a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) —
Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
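The two kwargs code paths described above can be sketched in plain Python. This is a simplified model under stated assumptions, not the library's implementation: ToyConfig and split_kwargs are hypothetical names, and a dataclass stands in for PretrainedConfig.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, Optional, Tuple


@dataclass
class ToyConfig:
    # Hypothetical stand-in for a PretrainedConfig subclass.
    output_attentions: bool = False
    num_labels: int = 2


def split_kwargs(
    loaded_config: ToyConfig,
    kwargs: Dict[str, Any],
    explicit_config: Optional[ToyConfig] = None,
) -> Tuple[ToyConfig, Dict[str, Any]]:
    """Return (config to use, kwargs forwarded to the model's __init__)."""
    if explicit_config is not None:
        # Case 1: a config was passed, so every kwarg goes straight to __init__.
        return explicit_config, dict(kwargs)
    # Case 2: keys that name a config attribute override the loaded config ...
    config_keys = {f.name for f in fields(ToyConfig)}
    overrides = {k: v for k, v in kwargs.items() if k in config_keys}
    # ... and every remaining key is forwarded to the model's __init__.
    model_kwargs = {k: v for k, v in kwargs.items() if k not in config_keys}
    return replace(loaded_config, **overrides), model_kwargs
```

For example, split_kwargs(ToyConfig(), {"output_attentions": True, "extra_flag": 1}) returns a config with output_attentions=True together with {"extra_flag": 1} for the model's __init__.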
Instantiate one of the model classes of the library (with an object detection head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- conditional_detr — ConditionalDetrForObjectDetection (Conditional DETR model)
- deformable_detr — DeformableDetrForObjectDetection (Deformable DETR model)
- deta — DetaForObjectDetection (DETA model)
- detr — DetrForObjectDetection (DETR model)
- rt_detr — RTDetrForObjectDetection (RT-DETR model)
- table-transformer — TableTransformerForObjectDetection (Table Transformer model)
- yolos — YolosForObjectDetection (YOLOS model)
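The selection rule above (model_type first, pattern matching on the name as a fallback) can be sketched as a lookup table. This is a hypothetical illustration that uses class names as strings; trying longer keys first is an assumption of this sketch, chosen so that a name containing conditional_detr is not claimed by the shorter detr key.

```python
from typing import Dict, Optional

# The model_type -> class name mapping listed above, as plain strings.
OBJECT_DETECTION_MAPPING: Dict[str, str] = {
    "conditional_detr": "ConditionalDetrForObjectDetection",
    "deformable_detr": "DeformableDetrForObjectDetection",
    "deta": "DetaForObjectDetection",
    "detr": "DetrForObjectDetection",
    "rt_detr": "RTDetrForObjectDetection",
    "table-transformer": "TableTransformerForObjectDetection",
    "yolos": "YolosForObjectDetection",
}


def resolve_class_name(model_type: Optional[str], name_or_path: str) -> str:
    if model_type is not None:
        # Preferred path: the config's model_type selects the class directly.
        return OBJECT_DETECTION_MAPPING[model_type]
    # Fallback: substring match on the name, longest keys first.
    for pattern in sorted(OBJECT_DETECTION_MAPPING, key=len, reverse=True):
        if pattern in name_or_path:
            return OBJECT_DETECTION_MAPPING[pattern]
    raise ValueError(f"Cannot infer a model class for {name_or_path!r}")
```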
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForObjectDetection
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")
>>> # Update configuration during loading
>>> model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForObjectDetection.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForImageSegmentation
This is a generic model class that will be instantiated as one of the model classes of the library (with an image segmentation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - DetrConfig configuration class: DetrForSegmentation (DETR model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an image segmentation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
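The default described for attn_implementation can be expressed as a small decision function. This is a sketch only: model_supports_sdpa is a hypothetical flag standing in for the library's per-model support check, and the real selection also depends on the installed packages and hardware.

```python
import re
from typing import Optional, Tuple

VALID_IMPLEMENTATIONS = ("eager", "sdpa", "flash_attention_2")


def parse_version(version: str) -> Tuple[int, ...]:
    # Keep only the leading numeric release, e.g. "2.1.1+cu118" -> (2, 1, 1).
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def choose_attn_implementation(
    requested: Optional[str], torch_version: str, model_supports_sdpa: bool
) -> str:
    if requested is not None:
        # An explicit choice always wins (if it is a known implementation).
        if requested not in VALID_IMPLEMENTATIONS:
            raise ValueError(f"Unknown attn_implementation: {requested!r}")
        return requested
    # Default: SDPA when supported and torch >= 2.1.1, otherwise manual "eager".
    if model_supports_sdpa and parse_version(torch_version) >= (2, 1, 1):
        return "sdpa"
    return "eager"
```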
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) —
Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
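To make output_loading_info concrete: missing keys are parameters the model defines but the checkpoint lacks, and unexpected keys are checkpoint entries the model has no parameter for. A set-difference sketch follows; loading_info and the parameter names are hypothetical, and the dictionary the library returns may contain more entries than the two computed here (error messages are omitted).

```python
from typing import Dict, Iterable, List


def loading_info(expected_keys: Iterable[str], loaded_keys: Iterable[str]) -> Dict[str, List[str]]:
    """Compare a model's parameter names with the keys found in a checkpoint."""
    expected, loaded = set(expected_keys), set(loaded_keys)
    return {
        # Defined by the model but absent from the checkpoint (left freshly initialized).
        "missing_keys": sorted(expected - loaded),
        # Present in the checkpoint but not used by the model.
        "unexpected_keys": sorted(loaded - expected),
    }


# A base checkpoint loaded into a model with a new task head: the head is missing,
# and the checkpoint's pooler weights go unused.
info = loading_info(
    expected_keys=["backbone.weight", "head.weight"],
    loaded_keys=["backbone.weight", "pooler.weight"],
)
print(info)  # {'missing_keys': ['head.weight'], 'unexpected_keys': ['pooler.weight']}
```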
Instantiate one of the model classes of the library (with an image segmentation head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- detr — DetrForSegmentation (DETR model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForImageSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForImageSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
>>> # Update configuration during loading
>>> model = AutoModelForImageSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForImageSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForImageToImage
AutoModelForSemanticSegmentation
This is a generic model class that will be instantiated as one of the model classes of the library (with a semantic segmentation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BeitConfig configuration class: BeitForSemanticSegmentation (BEiT model)
  - DPTConfig configuration class: DPTForSemanticSegmentation (DPT model)
  - Data2VecVisionConfig configuration class: Data2VecVisionForSemanticSegmentation (Data2VecVision model)
  - MobileNetV2Config configuration class: MobileNetV2ForSemanticSegmentation (MobileNetV2 model)
  - MobileViTConfig configuration class: MobileViTForSemanticSegmentation (MobileViT model)
  - MobileViTV2Config configuration class: MobileViTV2ForSemanticSegmentation (MobileViTV2 model)
  - SegformerConfig configuration class: SegformerForSemanticSegmentation (SegFormer model)
  - UperNetConfig configuration class: UperNetForSemanticSegmentation (UPerNet model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a semantic segmentation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
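from_config() picks the head class from the type of the configuration object rather than from a checkpoint name. The sketch below models that as a dictionary lookup keyed on the exact configuration class; the stand-in classes and model_class_for are hypothetical, and only a subset of the mapping above is included. In this sketch, a subclass of a listed config is deliberately not matched.

```python
from typing import Dict, Type


class BeitConfig:
    """Hypothetical stand-in for the real BeitConfig."""


class DPTConfig:
    """Hypothetical stand-in for the real DPTConfig."""


class SegformerConfig:
    """Hypothetical stand-in for the real SegformerConfig."""


# A subset of the configuration-class -> model-class mapping listed above.
SEMANTIC_SEGMENTATION_MAPPING: Dict[Type[object], str] = {
    BeitConfig: "BeitForSemanticSegmentation",
    DPTConfig: "DPTForSemanticSegmentation",
    SegformerConfig: "SegformerForSemanticSegmentation",
}


def model_class_for(config: object) -> str:
    # Exact type lookup: isinstance() is not used, so subclasses do not match.
    config_type = type(config)
    if config_type not in SEMANTIC_SEGMENTATION_MAPPING:
        raise ValueError(f"Unrecognized configuration class {config_type.__name__}")
    return SEMANTIC_SEGMENTATION_MAPPING[config_type]
```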
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) —
Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a semantic segmentation head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- beit — BeitForSemanticSegmentation (BEiT model)
- data2vec-vision — Data2VecVisionForSemanticSegmentation (Data2VecVision model)
- dpt — DPTForSemanticSegmentation (DPT model)
- mobilenet_v2 — MobileNetV2ForSemanticSegmentation (MobileNetV2 model)
- mobilevit — MobileViTForSemanticSegmentation (MobileViT model)
- mobilevitv2 — MobileViTV2ForSemanticSegmentation (MobileViTV2 model)
- segformer — SegformerForSemanticSegmentation (SegFormer model)
- upernet — UperNetForSemanticSegmentation (UPerNet model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
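The difference between the two modes can be seen with a toy dropout layer. ToyDropout is a hypothetical pure-Python stand-in, not torch.nn.Dropout; it only mirrors the train()/eval() convention of torch.nn.Module, where eval() is shorthand for train(False), and it uses inverted dropout (surviving values are scaled by 1 / (1 - p)).

```python
import random
from typing import List


class ToyDropout:
    """Hypothetical dropout layer that zeroes values only in training mode."""

    def __init__(self, p: float, seed: int = 0) -> None:
        if not 0.0 <= p < 1.0:
            raise ValueError("p must be in [0, 1)")
        self.p = p
        self.training = True  # layers start in training mode
        self._rng = random.Random(seed)

    def train(self, mode: bool = True) -> "ToyDropout":
        self.training = mode
        return self

    def eval(self) -> "ToyDropout":
        return self.train(False)

    def __call__(self, xs: List[float]) -> List[float]:
        if not self.training:
            return list(xs)  # evaluation mode: identity
        scale = 1.0 / (1.0 - self.p)  # keeps the expected value unchanged
        return [0.0 if self._rng.random() < self.p else x * scale for x in xs]


layer = ToyDropout(p=0.5).eval()
print(layer([1.0, 2.0, 3.0]))  # [1.0, 2.0, 3.0]
```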
Examples:
>>> from transformers import AutoConfig, AutoModelForSemanticSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
>>> # Update configuration during loading
>>> model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForSemanticSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForSemanticSegmentation
This is a generic model class that will be instantiated as one of the model classes of the library (with a semantic segmentation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Data2VecVisionConfig configuration class: TFData2VecVisionForSemanticSegmentation (Data2VecVision model)
  - MobileViTConfig configuration class: TFMobileViTForSemanticSegmentation (MobileViT model)
  - SegformerConfig configuration class: TFSegformerForSemanticSegmentation (SegFormer model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a semantic segmentation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model with the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) —
Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
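The proxies mapping can be keyed by protocol ('http') or by a specific endpoint ('http://hostname'). A sketch of how such a mapping might be resolved for a single request URL follows; select_proxy is a hypothetical helper whose most-specific-first order is loosely modeled on the requests library's lookup (which also consults 'all' keys, omitted here).

```python
from typing import Dict, Optional
from urllib.parse import urlparse


def select_proxy(url: str, proxies: Dict[str, str]) -> Optional[str]:
    parts = urlparse(url)
    # Most specific key first ("scheme://host"), then the bare scheme ("http").
    for key in (f"{parts.scheme}://{parts.hostname}", parts.scheme):
        if key in proxies:
            return proxies[key]
    return None  # no proxy configured for this URL


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(select_proxy("http://hostname/model.bin", proxies))  # foo.bar:4012
print(select_proxy("http://example.com/model.bin", proxies))  # foo.bar:3128
```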
Instantiate one of the model classes of the library (with a semantic segmentation head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- data2vec-vision — TFData2VecVisionForSemanticSegmentation (Data2VecVision model)
- mobilevit — TFMobileViTForSemanticSegmentation (MobileViT model)
- segformer — TFSegformerForSemanticSegmentation (SegFormer model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForSemanticSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
>>> # Update configuration during loading
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForInstanceSegmentation
This is a generic model class that will be instantiated as one of the model classes of the library (with an instance segmentation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - MaskFormerConfig configuration class: MaskFormerForInstanceSegmentation (MaskFormer model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an instance segmentation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model with the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) —
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
output_attentions=True
). Behaves differently depending on whether aconfig
is provided or automatically loaded:- If a configuration is provided with
config
,**kwargs
will be directly passed to the underlying model’s__init__
method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided,
kwargs
will be first passed to the configuration class initialization function (from_pretrained()). Each key ofkwargs
that corresponds to a configuration attribute will be used to override said attribute with the suppliedkwargs
value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s__init__
function.
- If a configuration is provided with
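The two ways kwargs are routed, as described above, can be illustrated with a small stdlib-only sketch. It is not the library's implementation: the function name split_kwargs, the use of a plain dict in place of the configuration object, and the example key my_flag are all hypothetical.

```python
from typing import Any, Dict, Tuple


def split_kwargs(
    config_attrs: Dict[str, Any],
    kwargs: Dict[str, Any],
    config_provided: bool,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical sketch of routing **kwargs in from_pretrained().

    config_attrs stands in for the configuration object's attributes.
    Returns (updated_config_attrs, kwargs_for_model_init).
    """
    if config_provided:
        # An explicit config is taken as final: every kwarg goes to __init__.
        return dict(config_attrs), dict(kwargs)

    updated = dict(config_attrs)
    model_kwargs: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in updated:
            # The key names a configuration attribute: override it.
            updated[key] = value
        else:
            # Anything else is forwarded to the model's __init__.
            model_kwargs[key] = value
    return updated, model_kwargs


loaded = {"output_attentions": False, "hidden_size": 768}
config, init_kwargs = split_kwargs(loaded, {"output_attentions": True, "my_flag": 1}, config_provided=False)
print(config["output_attentions"], init_kwargs)  # True {'my_flag': 1}
```

This is why output_attentions=True in the examples ends up on model.config: it matches a configuration attribute, so it overrides the loaded value instead of being passed to __init__.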
Instantiate one of the model classes of the library (with an instance segmentation head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- maskformer — MaskFormerForInstanceSegmentation (MaskFormer model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForInstanceSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
>>> # Update configuration during loading
>>> model = AutoModelForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForInstanceSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForUniversalSegmentation
This is a generic model class that will be instantiated as one of the model classes of the library (with a universal image segmentation head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - DetrConfig configuration class: DetrForSegmentation (DETR model)
  - Mask2FormerConfig configuration class: Mask2FormerForUniversalSegmentation (Mask2Former model)
  - MaskFormerConfig configuration class: MaskFormerForInstanceSegmentation (MaskFormer model)
  - OneFormerConfig configuration class: OneFormerForUniversalSegmentation (OneFormer model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
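The default rule for attn_implementation stated above (use SDPA when it is available and torch>=2.1.1, otherwise "eager") can be written as a small decision function. This is a simplified, hypothetical sketch: the function names and the model_supports_sdpa flag are assumptions, and the library itself performs additional per-model checks.

```python
import re
from typing import Optional, Tuple

# The three values documented for attn_implementation.
VALID_IMPLEMENTATIONS = ("eager", "sdpa", "flash_attention_2")


def parse_version(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings like '2.1.1' or '2.3.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def choose_attn_implementation(
    requested: Optional[str],
    torch_version: str,
    model_supports_sdpa: bool,
) -> str:
    """Hypothetical sketch of picking an attention backend."""
    if requested is not None:
        # An explicit choice is honoured as long as it is a documented value.
        if requested not in VALID_IMPLEMENTATIONS:
            raise ValueError(f"unknown attn_implementation: {requested!r}")
        return requested
    # No explicit choice: prefer SDPA when available on torch>=2.1.1.
    if model_supports_sdpa and parse_version(torch_version) >= (2, 1, 1):
        return "sdpa"
    # Otherwise fall back to the manual implementation.
    return "eager"
```

For example, choose_attn_implementation(None, "2.3.0+cu121", model_supports_sdpa=True) picks "sdpa", while the same call with torch "2.0.1" falls back to "eager".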
Instantiates one of the model classes of the library (with a universal image segmentation head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a universal image segmentation head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- detr — DetrForSegmentation (DETR model)
- mask2former — Mask2FormerForUniversalSegmentation (Mask2Former model)
- maskformer — MaskFormerForInstanceSegmentation (MaskFormer model)
- oneformer — OneFormerForUniversalSegmentation (OneFormer model)
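The selection rule above (use the config's model_type when present, otherwise fall back to pattern matching on the name or path) can be sketched with a plain dictionary. The mapping mirrors the list above; the function name and the matching strategy (substring search, trying longer keys first) are assumptions for illustration, not the library's exact lookup, and the sketch returns class names rather than classes.

```python
from typing import Dict, Optional

# model_type -> model class name, copied from the list above.
UNIVERSAL_SEGMENTATION_MAPPING: Dict[str, str] = {
    "detr": "DetrForSegmentation",
    "mask2former": "Mask2FormerForUniversalSegmentation",
    "maskformer": "MaskFormerForInstanceSegmentation",
    "oneformer": "OneFormerForUniversalSegmentation",
}


def resolve_model_class(name_or_path: str, model_type: Optional[str] = None) -> str:
    """Hypothetical sketch: pick a class name from model_type, else from the name."""
    if model_type is not None:
        # Preferred route: the config's model_type property.
        if model_type not in UNIVERSAL_SEGMENTATION_MAPPING:
            raise ValueError(f"unsupported model_type: {model_type!r}")
        return UNIVERSAL_SEGMENTATION_MAPPING[model_type]
    # Fallback: pattern matching on the name or path. Longer keys are tried
    # first so a specific key is preferred over a shorter one it contains.
    lowered = name_or_path.lower()
    for key in sorted(UNIVERSAL_SEGMENTATION_MAPPING, key=len, reverse=True):
        if key in lowered:
            return UNIVERSAL_SEGMENTATION_MAPPING[key]
    raise ValueError(f"cannot infer a model class from {name_or_path!r}")


print(resolve_model_class("facebook/mask2former-swin-small-coco-instance"))
```

A name such as "bert-base-cased" matches none of the keys, which is why a checkpoint must belong to one of the listed model types for this auto class to resolve it.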
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForUniversalSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-coco-instance")
>>> # Update configuration during loading
>>> model = AutoModelForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-coco-instance", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForUniversalSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForZeroShotImageClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a zero-shot image classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - AlignConfig configuration class: AlignModel (ALIGN model)
  - AltCLIPConfig configuration class: AltCLIPModel (AltCLIP model)
  - Blip2Config configuration class: Blip2ForImageTextRetrieval (BLIP-2 model)
  - BlipConfig configuration class: BlipModel (BLIP model)
  - CLIPConfig configuration class: CLIPModel (CLIP model)
  - CLIPSegConfig configuration class: CLIPSegModel (CLIPSeg model)
  - ChineseCLIPConfig configuration class: ChineseCLIPModel (Chinese-CLIP model)
  - SiglipConfig configuration class: SiglipModel (SigLIP model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a zero-shot image classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
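For the local-directory case listed under config above, automatic loading depends on a file named config.json being present in the directory. A minimal stdlib sketch of that check follows; find_local_config is a hypothetical helper, not part of the library, and it only returns the parsed JSON rather than a configuration object.

```python
import json
import os
from typing import Any, Dict, Optional


def find_local_config(pretrained_model_name_or_path: str) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: return the parsed config.json from a local directory.

    Returns None when the argument is not a local directory containing
    config.json, in which case a caller would treat it as a Hub model id
    or require an explicit config argument.
    """
    if not os.path.isdir(pretrained_model_name_or_path):
        return None
    config_file = os.path.join(pretrained_model_name_or_path, "config.json")
    if not os.path.isfile(config_file):
        return None
    with open(config_file, encoding="utf-8") as f:
        return json.load(f)
```

A directory produced by save_pretrained() contains such a config.json, which is why reloading from the save directory needs no explicit config.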
Instantiate one of the model classes of the library (with a zero-shot image classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- align — AlignModel (ALIGN model)
- altclip — AltCLIPModel (AltCLIP model)
- blip — BlipModel (BLIP model)
- blip-2 — Blip2ForImageTextRetrieval (BLIP-2 model)
- chinese_clip — ChineseCLIPModel (Chinese-CLIP model)
- clip — CLIPModel (CLIP model)
- clipseg — CLIPSegModel (CLIPSeg model)
- siglip — SiglipModel (SigLIP model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForZeroShotImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-base-patch32")
>>> # Update configuration during loading
>>> model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-base-patch32", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForZeroShotImageClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForZeroShotImageClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with a zero-shot image classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BlipConfig configuration class: TFBlipModel (BLIP model)
  - CLIPConfig configuration class: TFCLIPModel (CLIP model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a zero-shot image classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
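The proxies argument keys servers either by protocol ('http') or by protocol plus host ('http://hostname'), as in the example dictionary above. The lookup below is a hypothetical sketch of how such a dictionary is commonly resolved for one request URL, preferring the more specific key. It follows the convention of the requests library; that the download client resolves keys exactly this way is an assumption, not something stated here.

```python
from typing import Dict, List, Optional
from urllib.parse import urlsplit


def select_proxy(url: str, proxies: Dict[str, str]) -> Optional[str]:
    """Hypothetical sketch: pick the proxy for one request URL."""
    parts = urlsplit(url)
    # Most specific key first ("scheme://hostname"), then the bare scheme.
    candidates: List[str] = []
    if parts.hostname:
        candidates.append(f"{parts.scheme}://{parts.hostname}")
    candidates.append(parts.scheme)
    for key in candidates:
        if key in proxies:
            return proxies[key]
    # No matching entry: connect directly.
    return None


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(select_proxy("http://hostname/path", proxies))  # foo.bar:4012
```

Under this convention, requests to other HTTP hosts use foo.bar:3128, and HTTPS requests use no proxy because the dictionary has no 'https' entry.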
Instantiate one of the model classes of the library (with a zero-shot image classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- blip — TFBlipModel (BLIP model)
- clip — TFCLIPModel (CLIP model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForZeroShotImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-base-patch32")
>>> # Update configuration during loading
>>> model = TFAutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-base-patch32", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForZeroShotImageClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForZeroShotObjectDetection
This is a generic model class that will be instantiated as one of the model classes of the library (with a zero-shot object detection head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - GroundingDinoConfig configuration class: GroundingDinoForObjectDetection (Grounding DINO model)
  - OmDetTurboConfig configuration class: OmDetTurboForObjectDetection (OmDet-Turbo model)
  - OwlViTConfig configuration class: OwlViTForObjectDetection (OWL-ViT model)
  - Owlv2Config configuration class: Owlv2ForObjectDetection (OWLv2 model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a zero-shot object detection head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
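With output_loading_info=True, from_pretrained() also returns a dictionary of missing keys, unexpected keys and error messages. The sketch below shows how those categories relate to a model's expected parameter names and a checkpoint's entries, using plain dictionaries of shapes. It is a hypothetical illustration: the function name, the shape-mismatch message, and the example parameter names are assumptions, and the real loader additionally handles key prefixes and tied weights.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def compute_loading_info(
    model_params: Dict[str, Shape],
    checkpoint: Dict[str, Shape],
) -> Dict[str, List[str]]:
    """Hypothetical sketch: compare expected parameters with checkpoint entries."""
    # Parameters the model expects but the checkpoint does not provide.
    missing = sorted(set(model_params) - set(checkpoint))
    # Entries in the checkpoint that the model has no parameter for.
    unexpected = sorted(set(checkpoint) - set(model_params))
    # Names present on both sides whose shapes disagree.
    error_msgs = [
        f"size mismatch for {name}: checkpoint {checkpoint[name]}, model {model_params[name]}"
        for name in sorted(set(model_params) & set(checkpoint))
        if model_params[name] != checkpoint[name]
    ]
    return {"missing_keys": missing, "unexpected_keys": unexpected, "error_msgs": error_msgs}


info = compute_loading_info(
    {"backbone.w": (4, 4), "head.w": (2, 4)},
    {"backbone.w": (4, 4), "pooler.w": (4,)},
)
print(info["missing_keys"], info["unexpected_keys"])  # ['head.w'] ['pooler.w']
```

A typical case is loading a checkpoint trained with a different head: the new head shows up under missing keys and the old head's weights under unexpected keys.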
Instantiate one of the model classes of the library (with a zero-shot object detection head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- grounding-dino — GroundingDinoForObjectDetection (Grounding DINO model)
- omdet-turbo — OmDetTurboForObjectDetection (OmDet-Turbo model)
- owlv2 — Owlv2ForObjectDetection (OWLv2 model)
- owlvit — OwlViTForObjectDetection (OWL-ViT model)
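A minimal sketch of this two-step selection, assuming a hypothetical table that mirrors the list above (class names are kept as strings so the snippet runs without transformers installed). The config's model_type wins when present; otherwise the name is matched against the known keys, longest key first.

```python
from typing import Optional

# Hypothetical copy of the mapping listed above; values are class names only.
ZERO_SHOT_DETECTION_CLASSES = {
    "grounding-dino": "GroundingDinoForObjectDetection",
    "omdet-turbo": "OmDetTurboForObjectDetection",
    "owlv2": "Owlv2ForObjectDetection",
    "owlvit": "OwlViTForObjectDetection",
}


def select_class(model_type: Optional[str], name_or_path: str) -> str:
    # 1) Preferred route: the model_type stored in the config.
    if model_type is not None:
        if model_type not in ZERO_SHOT_DETECTION_CLASSES:
            raise ValueError(f"Unsupported model_type: {model_type!r}")
        return ZERO_SHOT_DETECTION_CLASSES[model_type]
    # 2) Fallback: look for a known key inside the name or path, longest first.
    for key in sorted(ZERO_SHOT_DETECTION_CLASSES, key=len, reverse=True):
        if key in name_or_path:
            return ZERO_SHOT_DETECTION_CLASSES[key]
    raise ValueError(f"Could not infer a model class from {name_or_path!r}")
```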
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForZeroShotObjectDetection
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained("google/owlvit-base-patch32")
>>> # Update configuration during loading
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained("google/owlvit-base-patch32", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
Audio
The following auto classes can be used for the audio tasks below.
AutoModelForAudioClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an audio classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - ASTConfig configuration class: ASTForAudioClassification (Audio Spectrogram Transformer model)
  - Data2VecAudioConfig configuration class: Data2VecAudioForSequenceClassification (Data2VecAudio model)
  - HubertConfig configuration class: HubertForSequenceClassification (Hubert model)
  - SEWConfig configuration class: SEWForSequenceClassification (SEW model)
  - SEWDConfig configuration class: SEWDForSequenceClassification (SEW-D model)
  - UniSpeechConfig configuration class: UniSpeechForSequenceClassification (UniSpeech model)
  - UniSpeechSatConfig configuration class: UniSpeechSatForSequenceClassification (UniSpeechSat model)
  - Wav2Vec2BertConfig configuration class: Wav2Vec2BertForSequenceClassification (Wav2Vec2-BERT model)
  - Wav2Vec2Config configuration class: Wav2Vec2ForSequenceClassification (Wav2Vec2 model)
  - Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForSequenceClassification (Wav2Vec2-Conformer model)
  - WavLMConfig configuration class: WavLMForSequenceClassification (WavLM model)
  - WhisperConfig configuration class: WhisperForAudioClassification (Whisper model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, SDPA will be used if available for torch>=2.1.1; otherwise, the default is the manual "eager" implementation.
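The default rule stated above can be written as a small decision function. default_attn_implementation is a hypothetical helper that encodes only that sentence (SDPA when available and torch>=2.1.1, otherwise eager); the library's own check considers more details, and "flash_attention_2" is never returned here because the text above does not make it a default.

```python
import re


def _version_tuple(version: str) -> tuple:
    # Keep only the leading "major.minor.patch" numbers, e.g. "2.1.0+cu118" -> (2, 1, 0).
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unparseable version: {version!r}")
    return tuple(int(part) for part in match.groups())


def default_attn_implementation(torch_version: str, model_supports_sdpa: bool) -> str:
    """Pick "sdpa" when the model supports it and torch >= 2.1.1, else "eager"."""
    if model_supports_sdpa and _version_tuple(torch_version) >= (2, 1, 1):
        return "sdpa"
    return "eager"
```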
Instantiates one of the model classes of the library (with an audio classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
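The lookup that from_config performs can be pictured as a dictionary keyed by configuration class. The sketch below uses hypothetical stand-in classes (ToyASTConfig, ToyWhisperConfig, ToyBertConfig) and returns class names instead of models; it only illustrates the two points made above: the head is chosen from the config's type, and no pretrained weights are involved. The real method also raises an error for configuration classes it does not recognize, although its exact message differs.

```python
class ToyASTConfig:
    """Stand-in for an audio configuration object (hypothetical)."""


class ToyWhisperConfig:
    """Stand-in for an audio configuration object (hypothetical)."""


class ToyBertConfig:
    """Stand-in for a config with no audio classification head (hypothetical)."""


# Keyed by configuration class, like the list above; values are class names.
AUDIO_CLASSIFICATION_BY_CONFIG = {
    ToyASTConfig: "ASTForAudioClassification",
    ToyWhisperConfig: "WhisperForAudioClassification",
}


def from_config_sketch(config):
    """Return (class name, weight source) for a config, mimicking from_config."""
    model_class = AUDIO_CLASSIFICATION_BY_CONFIG.get(type(config))
    if model_class is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    # from_config builds the architecture only; no checkpoint is read.
    return model_class, "randomly initialized"
```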
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. The configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, it can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been done already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
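With output_loading_info=True, the returned dictionary reports how the checkpoint's parameter names lined up with the model's. A typical case is loading a base checkpoint into a model with a fresh task head: the head's weights show up as missing. The sketch below reproduces only the set arithmetic behind missing and unexpected keys, using hypothetical parameter names; the real report also carries error messages and handles details such as model prefixes.

```python
from typing import Dict, Iterable, List


def loading_info(model_keys: Iterable[str], checkpoint_keys: Iterable[str]) -> Dict[str, List[str]]:
    """Compare parameter names expected by a model with those found in a checkpoint."""
    expected, found = set(model_keys), set(checkpoint_keys)
    return {
        # Parameters the model needs but the checkpoint lacks (left randomly initialized).
        "missing_keys": sorted(expected - found),
        # Parameters stored in the checkpoint that the model has no slot for (ignored).
        "unexpected_keys": sorted(found - expected),
    }
```

For instance, a checkpoint holding encoder and lm_head weights, loaded into a model with a classifier head, reports the classifier parameters as missing and lm_head.weight as unexpected.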
Instantiate one of the model classes of the library (with an audio classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- audio-spectrogram-transformer — ASTForAudioClassification (Audio Spectrogram Transformer model)
- data2vec-audio — Data2VecAudioForSequenceClassification (Data2VecAudio model)
- hubert — HubertForSequenceClassification (Hubert model)
- sew — SEWForSequenceClassification (SEW model)
- sew-d — SEWDForSequenceClassification (SEW-D model)
- unispeech — UniSpeechForSequenceClassification (UniSpeech model)
- unispeech-sat — UniSpeechSatForSequenceClassification (UniSpeechSat model)
- wav2vec2 — Wav2Vec2ForSequenceClassification (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertForSequenceClassification (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerForSequenceClassification (Wav2Vec2-Conformer model)
- wavlm — WavLMForSequenceClassification (WavLM model)
- whisper — WhisperForAudioClassification (Whisper model)
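Several keys in this list are prefixes of others (sew and sew-d, wav2vec2 and wav2vec2-bert or wav2vec2-conformer), so a name-based fallback has to check longer keys first. The sketch below contrasts a naive first-match scan with a longest-first scan on hypothetical repository names; the table mirrors the list above, with class names as strings.

```python
from typing import Optional

# Mirrors the model_type list above; values are class names only.
AUDIO_CLASSIFICATION_CLASSES = {
    "audio-spectrogram-transformer": "ASTForAudioClassification",
    "data2vec-audio": "Data2VecAudioForSequenceClassification",
    "hubert": "HubertForSequenceClassification",
    "sew": "SEWForSequenceClassification",
    "sew-d": "SEWDForSequenceClassification",
    "unispeech": "UniSpeechForSequenceClassification",
    "unispeech-sat": "UniSpeechSatForSequenceClassification",
    "wav2vec2": "Wav2Vec2ForSequenceClassification",
    "wav2vec2-bert": "Wav2Vec2BertForSequenceClassification",
    "wav2vec2-conformer": "Wav2Vec2ConformerForSequenceClassification",
    "wavlm": "WavLMForSequenceClassification",
    "whisper": "WhisperForAudioClassification",
}


def match_first(name: str) -> Optional[str]:
    # Naive: the first key (in table order) found in the name wins.
    for key, cls in AUDIO_CLASSIFICATION_CLASSES.items():
        if key in name:
            return cls
    return None


def match_longest_first(name: str) -> Optional[str]:
    # Safer: try longer keys before the keys that are their prefixes.
    for key in sorted(AUDIO_CLASSIFICATION_CLASSES, key=len, reverse=True):
        if key in name:
            return AUDIO_CLASSIFICATION_CLASSES[key]
    return None
```

With the naive scan, "my-org/sew-d-tiny" resolves to the SEW class; the longest-first scan correctly picks SEW-D.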
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
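Why eval() matters can be seen with a framework-free toy. ToyDropout below is a hypothetical class that mimics inverted dropout (zeroing inputs with probability p and rescaling survivors by 1/(1-p) during training, as torch.nn.Dropout does) and passes inputs through unchanged in evaluation mode; it is not the PyTorch class.

```python
import random
from typing import List


class ToyDropout:
    """Hypothetical inverted-dropout layer with a train/eval switch (not torch.nn.Dropout)."""

    def __init__(self, p: float, seed: int = 0):
        if not 0.0 <= p < 1.0:
            raise ValueError("p must be in [0, 1)")
        self.p = p
        self.training = True  # start in training mode, like freshly built torch modules
        self._rng = random.Random(seed)

    def train(self) -> "ToyDropout":
        self.training = True
        return self

    def eval(self) -> "ToyDropout":
        self.training = False
        return self

    def __call__(self, xs: List[float]) -> List[float]:
        if not self.training:
            # Evaluation mode: identity, so predictions are deterministic.
            return list(xs)
        scale = 1.0 / (1.0 - self.p)
        # Training mode: drop each value with probability p, rescale the survivors.
        return [0.0 if self._rng.random() < self.p else x * scale for x in xs]
```

Because from_pretrained() returns the model already in evaluation mode, its dropout layers behave like the eval() branch until model.train() is called.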
Examples:
>>> from transformers import AutoConfig, AutoModelForAudioClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForAudioClassification.from_pretrained("superb/wav2vec2-base-superb-ks")
>>> # Update configuration during loading
>>> model = AutoModelForAudioClassification.from_pretrained("superb/wav2vec2-base-superb-ks", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForAudioClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForAudioClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an audio classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Wav2Vec2Config configuration class: TFWav2Vec2ForSequenceClassification (Wav2Vec2 model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, SDPA will be used if available for torch>=2.1.1; otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with an audio classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. The configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, it can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been done already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
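The rule for single-file PyTorch checkpoints given for pretrained_model_name_or_path (set from_pt=True and pass a config) can be expressed as a pre-flight check. check_pt_checkpoint_args is a hypothetical helper, not part of transformers, and it only recognizes the .bin suffix used in the example path.

```python
from pathlib import Path
from typing import Any, Optional


def check_pt_checkpoint_args(path: str, from_pt: bool = False, config: Optional[Any] = None) -> bool:
    """Hypothetical pre-flight check mirroring the rule for single PyTorch state_dict files."""
    is_state_dict_file = Path(path).suffix == ".bin"
    if is_state_dict_file and not from_pt:
        raise ValueError(f"{path} looks like a PyTorch state_dict file; pass from_pt=True")
    if is_state_dict_file and config is None:
        raise ValueError("a config object is required when loading a single state_dict file")
    # Directories and model ids are not subject to this rule.
    return True
```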
Instantiate one of the model classes of the library (with an audio classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- wav2vec2 — TFWav2Vec2ForSequenceClassification (Wav2Vec2 model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForAudioClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForAudioClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForAudioClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForAudioClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForAudioFrameClassification
This is a generic model class that will be instantiated as one of the model classes of the library (with an audio frame (token) classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
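An audio frame classification head predicts one label per output frame of the acoustic encoder, so the number of predictions depends on the clip length. As a worked example, assume a Wav2Vec2-style convolutional feature encoder with the default Wav2Vec2Config settings conv_kernel=(10, 3, 3, 3, 3, 2, 2) and conv_stride=(5, 2, 2, 2, 2, 2, 2); each unpadded convolution maps a length L to floor((L - kernel) / stride) + 1. Other architectures in this family may downsample differently.

```python
from typing import Sequence

# Default feature-encoder settings of Wav2Vec2Config (assumed here for illustration).
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(num_samples: int,
                      kernels: Sequence[int] = CONV_KERNEL,
                      strides: Sequence[int] = CONV_STRIDE) -> int:
    """Length of the feature-encoder output for an unpadded waveform of num_samples."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        # Unpadded 1-D convolution: floor((L - kernel) / stride) + 1
        length = (length - kernel) // stride + 1
    return length
```

For one second of 16 kHz audio (16,000 samples) this gives 49 frames, roughly one label every 20 ms, since the total stride is 5 × 2^6 = 320 samples.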
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Data2VecAudioConfig configuration class: Data2VecAudioForAudioFrameClassification (Data2VecAudio model)
  - UniSpeechSatConfig configuration class: UniSpeechSatForAudioFrameClassification (UniSpeechSat model)
  - Wav2Vec2BertConfig configuration class: Wav2Vec2BertForAudioFrameClassification (Wav2Vec2-BERT model)
  - Wav2Vec2Config configuration class: Wav2Vec2ForAudioFrameClassification (Wav2Vec2 model)
  - Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForAudioFrameClassification (Wav2Vec2-Conformer model)
  - WavLMConfig configuration class: WavLMForAudioFrameClassification (WavLM model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, SDPA will be used if available for torch>=2.1.1; otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with an audio frame (token) classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. The configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, it can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been done already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
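The third auto-loading case for config depends only on a file named config.json sitting in the local directory. A hypothetical standard-library helper performing the same check might look like this; the real loader additionally handles Hub downloads, caching and validation.

```python
import json
from pathlib import Path
from typing import Any, Dict, Optional, Union


def find_local_config(directory: Union[str, Path]) -> Optional[Dict[str, Any]]:
    """Return the parsed config.json from `directory`, or None if there is none (hypothetical)."""
    config_path = Path(directory) / "config.json"
    if not config_path.is_file():
        return None  # no config.json: a config would have to be passed explicitly
    with config_path.open(encoding="utf-8") as f:
        return json.load(f)
```

In the library, the model_type entry stored in that file (for example "wav2vec2") is what drives the class selection by model_type.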
Instantiate one of the model classes of the library (with an audio frame (token) classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- data2vec-audio — Data2VecAudioForAudioFrameClassification (Data2VecAudio model)
- unispeech-sat — UniSpeechSatForAudioFrameClassification (UniSpeechSat model)
- wav2vec2 — Wav2Vec2ForAudioFrameClassification (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertForAudioFrameClassification (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerForAudioFrameClassification (Wav2Vec2-Conformer model)
- wavlm — WavLMForAudioFrameClassification (WavLM model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForAudioFrameClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd")
>>> # Update configuration during loading
>>> model = AutoModelForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForAudioFrameClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForCTC
This is a generic model class that will be instantiated as one of the model classes of the library (with a connectionist temporal classification head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
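A CTC head emits one label distribution per audio frame, including a special blank label, and a transcription is recovered by collapsing repeated labels and then removing blanks. The sketch below shows this greedy (best-path) decoding on frame ids that have already been argmaxed; the character vocabulary and blank_id=0 are assumptions for the example. In Wav2Vec2ForCTC the padding token id plays the role of the blank, so check your tokenizer's pad_token_id.

```python
from itertools import groupby
from typing import List, Sequence

# Hypothetical character vocabulary; id 0 is the blank.
VOCAB = {0: "<blank>", 1: "h", 2: "e", 3: "l", 4: "o"}


def ctc_greedy_decode(frame_ids: Sequence[int], blank_id: int = 0) -> List[int]:
    """Best-path CTC decoding: merge consecutive duplicates, then drop blanks."""
    collapsed = [label for label, _ in groupby(frame_ids)]
    return [label for label in collapsed if label != blank_id]
```

For the frame sequence [1, 1, 0, 2, 3, 3, 0, 3, 4, 4], the blank between the two runs of 3 keeps the double "l", and mapping the decoded ids through VOCAB spells "hello".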
from_config(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Data2VecAudioConfig configuration class: Data2VecAudioForCTC (Data2VecAudio model)
  - HubertConfig configuration class: HubertForCTC (Hubert model)
  - MCTCTConfig configuration class: MCTCTForCTC (M-CTC-T model)
  - SEWConfig configuration class: SEWForCTC (SEW model)
  - SEWDConfig configuration class: SEWDForCTC (SEW-D model)
  - UniSpeechConfig configuration class: UniSpeechForCTC (UniSpeech model)
  - UniSpeechSatConfig configuration class: UniSpeechSatForCTC (UniSpeechSat model)
  - Wav2Vec2BertConfig configuration class: Wav2Vec2BertForCTC (Wav2Vec2-BERT model)
  - Wav2Vec2Config configuration class: Wav2Vec2ForCTC (Wav2Vec2 model)
  - Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForCTC (Wav2Vec2-Conformer model)
  - WavLMConfig configuration class: WavLMForCTC (WavLM model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, SDPA will be used if available for torch>=2.1.1; otherwise, the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a connectionist temporal classification head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained(*model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. The configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() would be simpler.
- cache_dir (
str
oros.PathLike
, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - from_tf (
bool
, optional, defaults toFalse
) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
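The routing of kwargs described above can be illustrated with a small, library-independent sketch. The helper split_kwargs and the plain set of attribute names that stands in for a configuration object are hypothetical. The real from_pretrained() performs this split internally, and its exact rules may differ in edge cases.

```python
def split_kwargs(config_attributes, kwargs, config_provided):
    """Hypothetical helper mirroring the documented routing of **kwargs.

    Returns a tuple (config_overrides, model_init_kwargs).
    """
    if config_provided:
        # A config was passed explicitly: every kwarg goes straight to the model __init__.
        return {}, dict(kwargs)
    config_overrides = {}
    model_init_kwargs = {}
    for key, value in kwargs.items():
        if key in config_attributes:
            # Keys that name a configuration attribute override that attribute.
            config_overrides[key] = value
        else:
            # Everything else is forwarded to the model __init__.
            model_init_kwargs[key] = value
    return config_overrides, model_init_kwargs


# Hypothetical attribute names standing in for a real configuration class.
attrs = {"output_attentions", "hidden_size"}
overrides, init_kwargs = split_kwargs(
    attrs, {"output_attentions": True, "my_flag": 1}, config_provided=False
)
print(overrides)    # {'output_attentions': True}
print(init_kwargs)  # {'my_flag': 1}
```

When config_provided is True, the same call returns an empty override dictionary. This matches the assumption that the configuration has already been updated.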
Instantiate one of the model classes of the library (with a connectionist temporal classification head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- data2vec-audio — Data2VecAudioForCTC (Data2VecAudio model)
- hubert — HubertForCTC (Hubert model)
- mctct — MCTCTForCTC (M-CTC-T model)
- sew — SEWForCTC (SEW model)
- sew-d — SEWDForCTC (SEW-D model)
- unispeech — UniSpeechForCTC (UniSpeech model)
- unispeech-sat — UniSpeechSatForCTC (UniSpeechSat model)
- wav2vec2 — Wav2Vec2ForCTC (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertForCTC (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerForCTC (Wav2Vec2-Conformer model)
- wavlm — WavLMForCTC (WavLM model)
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
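The selection rule above uses model_type first and falls back to pattern matching on the name. It can be sketched as follows. This is a simplified illustration, not the library's code. The mapping holds a hypothetical subset of the table above, stored as class-name strings. The fallback assumes a longest-key-first substring match, so that a more specific key such as wav2vec2-conformer is tried before wav2vec2.

```python
# Hypothetical subset of the model_type -> class-name table listed above.
CTC_MAPPING = {
    "wav2vec2": "Wav2Vec2ForCTC",
    "wav2vec2-conformer": "Wav2Vec2ConformerForCTC",
    "hubert": "HubertForCTC",
    "wavlm": "WavLMForCTC",
}


def select_model_class(model_type, name_or_path, mapping=CTC_MAPPING):
    """Pick a class name from model_type, else by substring match on the name."""
    if model_type is not None:
        if model_type not in mapping:
            raise ValueError(f"Unrecognized model_type {model_type!r}")
        return mapping[model_type]
    # Fallback: try longer keys first so "wav2vec2-conformer" wins over "wav2vec2".
    for pattern in sorted(mapping, key=len, reverse=True):
        if pattern in name_or_path:
            return mapping[pattern]
    raise ValueError(f"Could not infer a model class from {name_or_path!r}")


print(select_model_class("hubert", "anything"))                      # HubertForCTC
print(select_model_class(None, "./my-wav2vec2-conformer-checkpoint"))  # Wav2Vec2ConformerForCTC
```

Sorting by key length is the design choice that matters in the fallback. Without it, the shorter wav2vec2 key would also match a wav2vec2-conformer path and could pick the wrong class.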
Examples:
>>> from transformers import AutoConfig, AutoModelForCTC
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
>>> # Update configuration during loading
>>> model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForCTC.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForSpeechSeq2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - MoonshineConfig configuration class: MoonshineForConditionalGeneration (Moonshine model)
  - Pop2PianoConfig configuration class: Pop2PianoForConditionalGeneration (Pop2Piano model)
  - SeamlessM4TConfig configuration class: SeamlessM4TForSpeechToText (SeamlessM4T model)
  - SeamlessM4Tv2Config configuration class: SeamlessM4Tv2ForSpeechToText (SeamlessM4Tv2 model)
  - Speech2TextConfig configuration class: Speech2TextForConditionalGeneration (Speech2Text model)
  - SpeechEncoderDecoderConfig configuration class: SpeechEncoderDecoderModel (Speech Encoder decoder model)
  - SpeechT5Config configuration class: SpeechT5ForSpeechToText (SpeechT5 model)
  - WhisperConfig configuration class: WhisperForConditionalGeneration (Whisper model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
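from_config() picks the model class from the type of the configuration object and builds the model without loading any pretrained weights. The sketch below mimics that lookup. It uses hypothetical stub classes (ConfigStub, WhisperConfigStub, WhisperModelStub, and so on) in place of the real configuration and model classes. Keying the dictionary on type(config) is an assumption made for illustration.

```python
class ConfigStub:
    """Hypothetical stand-in for a configuration base class."""


class WhisperConfigStub(ConfigStub):
    pass


class Speech2TextConfigStub(ConfigStub):
    pass


class ModelStub:
    def __init__(self, config):
        # Only the configuration is stored; no pretrained weights are loaded here.
        self.config = config


class WhisperModelStub(ModelStub):
    pass


class Speech2TextModelStub(ModelStub):
    pass


# Dispatch table keyed by the configuration class, mirroring the list above.
CONFIG_TO_MODEL = {
    WhisperConfigStub: WhisperModelStub,
    Speech2TextConfigStub: Speech2TextModelStub,
}


def from_config(config):
    """Build the model class registered for type(config), without weights."""
    model_class = CONFIG_TO_MODEL.get(type(config))
    if model_class is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    return model_class(config)


model = from_config(WhisperConfigStub())
print(type(model).__name__)  # WhisperModelStub
```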
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be one of:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
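With output_loading_info=True, from_pretrained() returns the model together with a dictionary that describes how the checkpoint matched the model. The sketch below shows, under simplifying assumptions, how missing and unexpected keys can be derived by comparing parameter names. compute_loading_info is a hypothetical helper, and the real dictionary may contain additional entries.

```python
def compute_loading_info(model_keys, checkpoint_keys):
    """Hypothetical sketch of the information output_loading_info=True reports."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    return {
        # Parameters the model expects but the checkpoint does not provide
        # (these would keep their fresh initialization).
        "missing_keys": sorted(model_keys - checkpoint_keys),
        # Checkpoint entries the model has no parameter for (these would be ignored).
        "unexpected_keys": sorted(checkpoint_keys - model_keys),
        # This sketch does not detect loading errors, so the list stays empty.
        "error_msgs": [],
    }


info = compute_loading_info(
    model_keys=["encoder.weight", "proj_out.weight"],
    checkpoint_keys=["encoder.weight", "pooler.weight"],
)
print(info["missing_keys"])     # ['proj_out.weight']
print(info["unexpected_keys"])  # ['pooler.weight']
```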
Instantiate one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- moonshine — MoonshineForConditionalGeneration (Moonshine model)
- pop2piano — Pop2PianoForConditionalGeneration (Pop2Piano model)
- seamless_m4t — SeamlessM4TForSpeechToText (SeamlessM4T model)
- seamless_m4t_v2 — SeamlessM4Tv2ForSpeechToText (SeamlessM4Tv2 model)
- speech-encoder-decoder — SpeechEncoderDecoderModel (Speech Encoder decoder model)
- speech_to_text — Speech2TextForConditionalGeneration (Speech2Text model)
- speecht5 — SpeechT5ForSpeechToText (SpeechT5 model)
- whisper — WhisperForConditionalGeneration (Whisper model)
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForSpeechSeq2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
>>> # Update configuration during loading
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForSpeechSeq2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Speech2TextConfig configuration class: TFSpeech2TextForConditionalGeneration (Speech2Text model)
  - WhisperConfig configuration class: TFWhisperForConditionalGeneration (Whisper model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be one of:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
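The proxies dictionary can be keyed either by protocol ('http') or by protocol plus host ('http://hostname'). Below is a minimal sketch of resolving such a dictionary for one request URL. It assumes that the more specific protocol-plus-host key wins over the bare protocol key. The lookup order is modeled on the behavior of the requests library, and pick_proxy is a hypothetical name.

```python
from urllib.parse import urlparse


def pick_proxy(url, proxies):
    """Return the proxy for url, preferring a scheme://host key over a bare scheme key."""
    parts = urlparse(url)
    if parts.hostname is None:
        return proxies.get(parts.scheme, proxies.get("all"))
    candidates = [
        f"{parts.scheme}://{parts.hostname}",  # most specific: protocol + endpoint
        parts.scheme,                           # protocol only
        f"all://{parts.hostname}",              # any protocol, this endpoint
        "all",                                  # catch-all
    ]
    for key in candidates:
        if key in proxies:
            return proxies[key]
    return None


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(pick_proxy("http://hostname/path", proxies))  # foo.bar:4012
print(pick_proxy("http://other.host/", proxies))    # foo.bar:3128
print(pick_proxy("https://hostname/", proxies))     # None
```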
Instantiate one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- speech_to_text — TFSpeech2TextForConditionalGeneration (Speech2Text model)
- whisper — TFWhisperForConditionalGeneration (Whisper model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForSpeechSeq2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
>>> # Update configuration during loading
>>> model = TFAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForSpeechSeq2Seq.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSpeechSeq2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - SpeechEncoderDecoderConfig configuration class: FlaxSpeechEncoderDecoderModel (Speech Encoder decoder model)
  - WhisperConfig configuration class: FlaxWhisperForConditionalGeneration (Whisper model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be one of:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
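One case in which config can be loaded automatically is a local directory that contains a config.json file. Below is a minimal sketch of that check, using only the standard library. can_autoload_config is a hypothetical helper. It covers only the local-directory case, not model ids resolved on huggingface.co.

```python
import os
import tempfile


def can_autoload_config(pretrained_model_name_or_path):
    """Hypothetical check: True when a local directory holds a config.json file."""
    path = os.fspath(pretrained_model_name_or_path)
    return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))


with tempfile.TemporaryDirectory() as model_dir:
    before = can_autoload_config(model_dir)  # False: the directory has no config.json yet
    with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as f:
        f.write('{"model_type": "whisper"}')
    after = can_autoload_config(model_dir)   # True: config.json is now present

print(before, after)  # False True
```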
Instantiate one of the model classes of the library (with a sequence-to-sequence speech-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it’s missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- speech-encoder-decoder — FlaxSpeechEncoderDecoderModel (Speech Encoder decoder model)
- whisper — FlaxWhisperForConditionalGeneration (Whisper model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForSpeechSeq2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForAudioXVector
This is a generic model class that will be instantiated as one of the model classes of the library (with an audio retrieval via x-vector head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Data2VecAudioConfig configuration class: Data2VecAudioForXVector (Data2VecAudio model)
  - UniSpeechSatConfig configuration class: UniSpeechSatForXVector (UniSpeechSat model)
  - Wav2Vec2BertConfig configuration class: Wav2Vec2BertForXVector (Wav2Vec2-BERT model)
  - Wav2Vec2Config configuration class: Wav2Vec2ForXVector (Wav2Vec2 model)
  - Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForXVector (Wav2Vec2-Conformer model)
  - WavLMConfig configuration class: WavLMForXVector (WavLM model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with an audio retrieval via x-vector head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
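The documented default for attn_implementation can be written as a small decision function: SDPA when available on torch>=2.1.1, otherwise "eager". This is a sketch under stated assumptions. default_attn_implementation is hypothetical, and sdpa_available stands in for whatever availability check the library performs. The version parsing assumes plain numeric versions such as 2.3.0, optionally followed by a +local suffix.

```python
def default_attn_implementation(requested, torch_version, sdpa_available):
    """Hypothetical sketch of the documented default for attn_implementation."""
    allowed = {"eager", "sdpa", "flash_attention_2"}
    if requested is not None:
        # An explicit choice is kept as long as it is one of the documented values.
        if requested not in allowed:
            raise ValueError(f"Unknown attn_implementation {requested!r}")
        return requested
    # Compare (major, minor, patch) tuples; a local suffix such as "+cu121" is dropped.
    core = torch_version.split("+")[0]
    version = tuple(int(part) for part in core.split(".")[:3])
    if sdpa_available and version >= (2, 1, 1):
        return "sdpa"
    return "eager"


print(default_attn_implementation(None, "2.3.0+cu121", sdpa_available=True))  # sdpa
print(default_attn_implementation(None, "2.0.1", sdpa_available=True))        # eager
print(default_attn_implementation("eager", "2.3.0", sdpa_available=True))     # eager
```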
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be one of:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check whether using save_pretrained() and from_pretrained() is a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
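pretrained_model_name_or_path accepts three kinds of input. The sketch below tells them apart with standard-library checks. classify_source and its return labels are hypothetical, and the real loader applies further rules (for example, resolving individual files inside a Hub repository).

```python
import os
import tempfile


def classify_source(pretrained_model_name_or_path):
    """Hypothetical helper: label which of the documented input forms was given."""
    path = os.fspath(pretrained_model_name_or_path)
    if os.path.isdir(path):
        # e.g. ./my_model_directory/ written by save_pretrained()
        return "directory"
    if os.path.isfile(path) or path.startswith(("http://", "https://")):
        # e.g. ./tf_model/model.ckpt.index, which also requires from_tf=True and config
        return "file_or_url"
    # Anything else is treated as a model id on huggingface.co
    return "hub_model_id"


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "model.ckpt.index")
    open(ckpt, "w").close()  # empty placeholder standing in for a checkpoint file
    kinds = (classify_source(tmp), classify_source(ckpt), classify_source("org/model-name"))

print(kinds)  # ('directory', 'file_or_url', 'hub_model_id')
```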
Instantiate one of the model classes of the library (with an audio retrieval via x-vector head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- data2vec-audio — Data2VecAudioForXVector (Data2VecAudio model)
- unispeech-sat — UniSpeechSatForXVector (UniSpeechSat model)
- wav2vec2 — Wav2Vec2ForXVector (Wav2Vec2 model)
- wav2vec2-bert — Wav2Vec2BertForXVector (Wav2Vec2-BERT model)
- wav2vec2-conformer — Wav2Vec2ConformerForXVector (Wav2Vec2-Conformer model)
- wavlm — WavLMForXVector (WavLM model)
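The selection rule above (model_type first, then pattern matching on the name) can be sketched in plain Python. XVECTOR_MAPPING and select_model_class are hypothetical illustrations that hold class names as strings. Trying longer patterns first, so that a name containing wav2vec2-conformer is not captured by the shorter wav2vec2 key, reflects our reading of how AutoConfig orders its fallback and should be checked against the library source.

```python
# Class names copied from the x-vector list above (strings only, no real classes).
XVECTOR_MAPPING = {
    "data2vec-audio": "Data2VecAudioForXVector",
    "unispeech-sat": "UniSpeechSatForXVector",
    "wav2vec2": "Wav2Vec2ForXVector",
    "wav2vec2-bert": "Wav2Vec2BertForXVector",
    "wav2vec2-conformer": "Wav2Vec2ConformerForXVector",
    "wavlm": "WavLMForXVector",
}


def select_model_class(name_or_path, model_type=None):
    """Pick a class name by model_type, else by substring match on the name or path."""
    if model_type is not None:
        if model_type not in XVECTOR_MAPPING:
            raise ValueError(f"Unrecognized model_type {model_type!r}")
        return XVECTOR_MAPPING[model_type]
    # Longer keys first, so "wav2vec2-conformer" wins over its prefix "wav2vec2".
    for pattern in sorted(XVECTOR_MAPPING, key=len, reverse=True):
        if pattern in name_or_path:
            return XVECTOR_MAPPING[pattern]
    raise ValueError(f"Unrecognized model in {name_or_path!r}")


print(select_model_class("my-org/wav2vec2-conformer-sv"))  # Wav2Vec2ConformerForXVector
print(select_model_class("my-org/wav2vec2-base-sv"))  # Wav2Vec2ForXVector
```

The repository names used here are made up; only the mapping keys and class names come from the list above.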
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, first set it back to training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForAudioXVector
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForAudioXVector.from_pretrained("microsoft/wavlm-base-plus-sv")
>>> # Update configuration during loading
>>> model = AutoModelForAudioXVector.from_pretrained("microsoft/wavlm-base-plus-sv", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForAudioXVector.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForTextToSpectrogram
AutoModelForTextToWaveform
Multimodal
The following auto classes can be used for the multimodal tasks below.
AutoModelForTableQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a table question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - TapasConfig configuration class: TapasForQuestionAnswering (TAPAS model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 if available; otherwise the default is the manual "eager" implementation.
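The default rule for attn_implementation stated above can be restated as a small function. This is a sketch of the documented rule only, not how transformers decides internally: default_attn_implementation is hypothetical, and sdpa_available is a hypothetical flag standing for "the installed torch and this model support SDPA".

```python
def default_attn_implementation(torch_version, sdpa_available):
    """Documented default: "sdpa" for torch>=2.1.1 when available, else "eager"."""
    # Drop a local build suffix such as "+cu121"; assumes a plain numeric version.
    numeric = torch_version.split("+")[0]
    parts = tuple(int(p) for p in numeric.split(".")[:3])
    if sdpa_available and parts >= (2, 1, 1):
        return "sdpa"
    return "eager"


print(default_attn_implementation("2.3.0+cu121", True))  # sdpa
print(default_attn_implementation("2.0.1", True))  # eager
```

To bypass the default, the parameter can be passed explicitly, e.g. AutoModelForTableQuestionAnswering.from_config(config, attn_implementation="eager"); whether "sdpa" or "flash_attention_2" is accepted depends on the model.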
Instantiates one of the model classes of the library (with a table question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() is a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id: since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code from the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id: since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been made already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
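The third automatic-loading condition listed under config (a local directory containing config.json) can be checked with a few lines of standard-library Python. has_local_config is a hypothetical helper, not part of transformers, and it does not cover Hub model ids. A directory written by save_pretrained() also contains a config.json, so the second condition passes the same check.

```python
import os
import tempfile


def has_local_config(pretrained_model_name_or_path):
    """True if the argument is a local directory that contains a config.json file."""
    path = os.fspath(pretrained_model_name_or_path)
    return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))


with tempfile.TemporaryDirectory() as model_dir:
    before = has_local_config(model_dir)  # empty directory: no config.json yet
    with open(os.path.join(model_dir, "config.json"), "w") as f:
        f.write('{"model_type": "tapas"}')
    after = has_local_config(model_dir)

print(before, after)  # False True
```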
Instantiate one of the model classes of the library (with a table question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- tapas — TapasForQuestionAnswering (TAPAS model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, first set it back to training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")
>>> # Update configuration during loading
>>> model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/tapas_tf_model_config.json")
>>> model = AutoModelForTableQuestionAnswering.from_pretrained(
... "./tf_model/tapas_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForTableQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a table question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - TapasConfig configuration class: TFTapasForQuestionAnswering (TAPAS model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 if available; otherwise the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a table question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
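To make the note above concrete, here is a toy model in plain Python (a hypothetical ToyModel class, unrelated to transformers internals). The configuration only fixes how many weights exist. The weight values either stay randomly initialized (from_config) or are copied from stored weights (from_pretrained). Passing the weights in directly is a simplification; the real from_pretrained reads them from a checkpoint.

```python
import random


class ToyModel:
    """Hypothetical model: the config fixes the architecture, the weights come separately."""

    def __init__(self, config):
        self.config = config
        # Fresh (random) initialization, as when building a model from a config alone.
        self.weights = [random.random() for _ in range(config["num_weights"])]

    @classmethod
    def from_config(cls, config):
        # Architecture only: the weights stay randomly initialized.
        return cls(config)

    @classmethod
    def from_pretrained(cls, config, saved_weights):
        # Same architecture, then the random init is overwritten with stored weights.
        model = cls(config)
        if len(saved_weights) != config["num_weights"]:
            raise ValueError("saved weights do not match the configuration")
        model.weights = list(saved_weights)
        return model


config = {"num_weights": 3}
pretrained = ToyModel.from_pretrained(config, [0.1, 0.2, 0.3])
fresh = ToyModel.from_config(config)
print(pretrained.weights)  # [0.1, 0.2, 0.3]
print(len(fresh.weights))  # 3
```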
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id: since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code from the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id: since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been made already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
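The proxies example above mixes a protocol key ('http') with a protocol-plus-host key ('http://hostname'). The sketch below shows how such a dictionary is typically resolved, modeled on the precedence used by the requests library's select_proxy helper (scheme and host, then scheme, then all plus host, then all). Whether a given transformers or huggingface_hub version resolves proxies exactly this way is an assumption, and this select_proxy is a standalone reimplementation for illustration.

```python
from urllib.parse import urlparse


def select_proxy(url, proxies):
    """Pick the proxy for a URL, trying the most specific key first."""
    parts = urlparse(url)
    if parts.hostname is None:
        return proxies.get(parts.scheme, proxies.get("all"))
    for key in (
        f"{parts.scheme}://{parts.hostname}",  # scheme + host, e.g. "http://hostname"
        parts.scheme,  # scheme only, e.g. "http"
        f"all://{parts.hostname}",  # any scheme for this host
        "all",  # catch-all
    ):
        if key in proxies:
            return proxies[key]
    return None


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(select_proxy("http://hostname/model.bin", proxies))  # foo.bar:4012
print(select_proxy("http://example.com/model.bin", proxies))  # foo.bar:3128
```

With this dictionary, HTTPS requests get no proxy at all, which is a common surprise when only an 'http' key is configured.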
Instantiate one of the model classes of the library (with a table question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- tapas — TFTapasForQuestionAnswering (TAPAS model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForTableQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")
>>> # Update configuration during loading
>>> model = TFAutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/tapas_pt_model_config.json")
>>> model = TFAutoModelForTableQuestionAnswering.from_pretrained(
... "./pt_model/tapas_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForDocumentQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a document question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - LayoutLMConfig configuration class: LayoutLMForQuestionAnswering (LayoutLM model)
  - LayoutLMv2Config configuration class: LayoutLMv2ForQuestionAnswering (LayoutLMv2 model)
  - LayoutLMv3Config configuration class: LayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 if available; otherwise the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a document question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
Examples:
>>> from transformers import AutoConfig, AutoModelForDocumentQuestionAnswering
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> model = AutoModelForDocumentQuestionAnswering.from_config(config)
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, check whether using save_pretrained() and from_pretrained() is a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id: since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code from the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id: since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been made already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
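With output_loading_info=True, from_pretrained() returns a (model, loading_info) tuple instead of just the model. Conceptually, the missing and unexpected keys are set differences between the parameter names the model defines and those stored in the checkpoint. The sketch below shows only that idea: compare_keys is a hypothetical helper, the parameter names are invented for illustration, and the library's real dictionary may contain additional entries.

```python
def compare_keys(expected_keys, checkpoint_keys):
    """Build a loading-info style report from two collections of parameter names."""
    expected, found = set(expected_keys), set(checkpoint_keys)
    return {
        # Defined by the model but absent from the checkpoint (left newly initialized).
        "missing_keys": sorted(expected - found),
        # Stored in the checkpoint but with no matching model parameter (ignored).
        "unexpected_keys": sorted(found - expected),
        "error_msgs": [],
    }


info = compare_keys(
    ["layoutlm.embeddings.word_embeddings.weight", "qa_outputs.weight", "qa_outputs.bias"],
    ["layoutlm.embeddings.word_embeddings.weight", "cls.predictions.bias"],
)
print(info["missing_keys"])  # ['qa_outputs.bias', 'qa_outputs.weight']
print(info["unexpected_keys"])  # ['cls.predictions.bias']
```

A freshly added question answering head typically shows up under missing keys, which is why such a model needs fine-tuning before its answers are meaningful.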
Instantiate one of the model classes of the library (with a document question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- layoutlm — LayoutLMForQuestionAnswering (LayoutLM model)
- layoutlmv2 — LayoutLMv2ForQuestionAnswering (LayoutLMv2 model)
- layoutlmv3 — LayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, first set it back to training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForDocumentQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> # Update configuration during loading
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/layoutlm_tf_model_config.json")
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained(
... "./tf_model/layoutlm_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForDocumentQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a document question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - LayoutLMConfig configuration class: TFLayoutLMForQuestionAnswering (LayoutLM model)
  - LayoutLMv3Config configuration class: TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, SDPA is used for torch>=2.1.1 if available; otherwise the default is the manual "eager" implementation.
Instantiates one of the model classes of the library (with a document question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
Examples:
>>> from transformers import AutoConfig, TFAutoModelForDocumentQuestionAnswering
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> model = TFAutoModelForDocumentQuestionAnswering.from_config(config)
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id: since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code from the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id: since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been made already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a document question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- layoutlm — TFLayoutLMForQuestionAnswering (LayoutLM model)
- layoutlmv3 — TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 model)
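Comparing this list with the PyTorch one for AutoModelForDocumentQuestionAnswering shows that the TensorFlow auto class covers fewer architectures: LayoutLMv2 appears only in the PyTorch mapping. The sketch below models each framework's mapping as a plain dictionary keyed by model_type. The dictionaries and resolve_doc_qa_class are hypothetical illustrations, and the real library's error type and message may differ.

```python
# Class names copied from the two document QA lists (strings only, no real classes).
PT_DOC_QA = {
    "layoutlm": "LayoutLMForQuestionAnswering",
    "layoutlmv2": "LayoutLMv2ForQuestionAnswering",
    "layoutlmv3": "LayoutLMv3ForQuestionAnswering",
}
TF_DOC_QA = {
    "layoutlm": "TFLayoutLMForQuestionAnswering",
    "layoutlmv3": "TFLayoutLMv3ForQuestionAnswering",
}
MAPPINGS = {"pt": PT_DOC_QA, "tf": TF_DOC_QA}


def resolve_doc_qa_class(model_type, framework):
    """Return the head class name for a model_type in the given framework ("pt" or "tf")."""
    mapping = MAPPINGS[framework]
    if model_type not in mapping:
        raise ValueError(f"No {framework} document QA class for model type {model_type!r}")
    return mapping[model_type]


print(resolve_doc_qa_class("layoutlmv3", "tf"))  # TFLayoutLMv3ForQuestionAnswering
```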
Examples:
>>> from transformers import AutoConfig, TFAutoModelForDocumentQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> # Update configuration during loading
>>> model = TFAutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/layoutlm_pt_model_config.json")
>>> model = TFAutoModelForDocumentQuestionAnswering.from_pretrained(
... "./pt_model/layoutlm_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForVisualQuestionAnswering
This is a generic model class that will be instantiated as one of the model classes of the library (with a visual question answering head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__()
(throws an error).
from_config
< source >( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Blip2Config configuration class: Blip2ForConditionalGeneration (BLIP-2 model)
  - BlipConfig configuration class: BlipForQuestionAnswering (BLIP model)
  - ViltConfig configuration class: ViltForQuestionAnswering (ViLT model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a visual question answering head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
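The dispatch performed by from_config can be pictured as a lookup keyed on the configuration class itself (rather than on a model_type string), followed by plain construction with freshly initialized weights. The sketch below is hypothetical: the Toy* classes and CONFIG_TO_MODEL only stand in for the library's configuration and model classes, and no real weights are involved.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for two configuration classes.
@dataclass
class ToyViltConfig:
    hidden_size: int = 8


@dataclass
class ToyBlipConfig:
    hidden_size: int = 16


class _ToyModel:
    def __init__(self, config):
        self.config = config
        # from_config only builds the architecture: no pretrained weights are read.
        self.pretrained_weights_loaded = False


class ToyViltForQuestionAnswering(_ToyModel):
    pass


class ToyBlipForQuestionAnswering(_ToyModel):
    pass


# The lookup key is the configuration class, not a string such as model_type.
CONFIG_TO_MODEL = {
    ToyViltConfig: ToyViltForQuestionAnswering,
    ToyBlipConfig: ToyBlipForQuestionAnswering,
}


def from_config(config):
    model_cls = CONFIG_TO_MODEL.get(type(config))
    if model_cls is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    return model_cls(config)


model = from_config(ToyViltConfig(hidden_size=32))
print(type(model).__name__, model.config.hidden_size)  # ToyViltForQuestionAnswering 32
```

Keying on the class (instead of a string) is also what makes the register(ConfigClass, ModelClass) style of extension natural: adding an entry to the table is enough for the lookup to find a new pair.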
from_pretrained( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, check whether using save_pretrained() and from_pretrained() would be a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
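The two kwargs cases described above can be sketched as a small routing function. This is a hypothetical simplification (ToyConfig and split_kwargs are illustrative names, and a dataclass stands in for a real configuration class); it only shows which keys end up on the configuration and which are forwarded to the model's __init__.

```python
from dataclasses import dataclass, fields, replace


@dataclass
class ToyConfig:  # hypothetical configuration with two attributes
    output_attentions: bool = False
    hidden_size: int = 8


def split_kwargs(config, config_was_provided, **kwargs):
    """Return (effective_config, model_init_kwargs) for the two cases above."""
    if config_was_provided:
        # Case 1: an explicit config is assumed final; every kwarg goes to __init__.
        return config, dict(kwargs)
    # Case 2: kwargs that name a config attribute override it; the rest go to __init__.
    attribute_names = {f.name for f in fields(config)}
    overrides = {k: v for k, v in kwargs.items() if k in attribute_names}
    model_kwargs = {k: v for k, v in kwargs.items() if k not in attribute_names}
    return replace(config, **overrides), model_kwargs


cfg, rest = split_kwargs(ToyConfig(), False, output_attentions=True, dropout_seed=0)
print(cfg.output_attentions, rest)  # True {'dropout_seed': 0}
```

This is why, in the examples below, passing output_attentions=True without a config shows up as model.config.output_attentions being True.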
Instantiate one of the model classes of the library (with a visual question answering head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- blip — BlipForQuestionAnswering (BLIP model)
- blip-2 — Blip2ForConditionalGeneration (BLIP-2 model)
- vilt — ViltForQuestionAnswering (ViLT model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
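What "dropout modules are deactivated" means can be shown without any deep-learning framework. ToyDropout below is a hypothetical stand-in that follows the PyTorch convention where train(mode) returns the module itself and eval() is train(False). In training mode it zeroes each value with probability p and rescales the survivors by 1/(1 - p); in evaluation mode it is the identity. It assumes 0 <= p < 1.

```python
import random


class ToyDropout:
    """Hypothetical dropout layer with a PyTorch-style training flag."""

    def __init__(self, p=0.5, seed=0):
        self.p = p  # assumed to satisfy 0 <= p < 1
        self.training = True  # layers start in training mode, as in PyTorch
        self._rng = random.Random(seed)

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def __call__(self, xs):
        if not self.training:
            return list(xs)  # evaluation mode: identity, nothing is dropped
        keep_scale = 1.0 / (1.0 - self.p)
        # Training mode: zero each value with probability p, rescale the survivors.
        return [0.0 if self._rng.random() < self.p else x * keep_scale for x in xs]


layer = ToyDropout(p=0.5).eval()  # the state a freshly loaded model is in
print(layer([1.0, 2.0, 3.0]))  # [1.0, 2.0, 3.0]
```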
Examples:
>>> from transformers import AutoConfig, AutoModelForVisualQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForVisualQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
>>> # Update configuration during loading
>>> model = AutoModelForVisualQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/vilt_tf_model_config.json")
>>> model = AutoModelForVisualQuestionAnswering.from_pretrained(
... "./tf_model/vilt_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForVision2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a vision-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (doing so throws an error).
from_config( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - Blip2Config configuration class: Blip2ForConditionalGeneration (BLIP-2 model)
  - BlipConfig configuration class: BlipForConditionalGeneration (BLIP model)
  - ChameleonConfig configuration class: ChameleonForConditionalGeneration (Chameleon model)
  - GitConfig configuration class: GitForCausalLM (GIT model)
  - Idefics2Config configuration class: Idefics2ForConditionalGeneration (Idefics2 model)
  - Idefics3Config configuration class: Idefics3ForConditionalGeneration (Idefics3 model)
  - InstructBlipConfig configuration class: InstructBlipForConditionalGeneration (InstructBLIP model)
  - InstructBlipVideoConfig configuration class: InstructBlipVideoForConditionalGeneration (InstructBlipVideo model)
  - Kosmos2Config configuration class: Kosmos2ForConditionalGeneration (KOSMOS-2 model)
  - LlavaConfig configuration class: LlavaForConditionalGeneration (LLaVa model)
  - LlavaNextConfig configuration class: LlavaNextForConditionalGeneration (LLaVA-NeXT model)
  - LlavaNextVideoConfig configuration class: LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video model)
  - LlavaOnevisionConfig configuration class: LlavaOnevisionForConditionalGeneration (LLaVA-Onevision model)
  - MllamaConfig configuration class: MllamaForConditionalGeneration (Mllama model)
  - PaliGemmaConfig configuration class: PaliGemmaForConditionalGeneration (PaliGemma model)
  - Pix2StructConfig configuration class: Pix2StructForConditionalGeneration (Pix2Struct model)
  - Qwen2VLConfig configuration class: Qwen2VLForConditionalGeneration (Qwen2VL model)
  - VideoLlavaConfig configuration class: VideoLlavaForConditionalGeneration (VideoLlava model)
  - VipLlavaConfig configuration class: VipLlavaForConditionalGeneration (VipLlava model)
  - VisionEncoderDecoderConfig configuration class: VisionEncoderDecoderModel (Vision Encoder decoder model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a vision-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
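The documented default for attn_implementation can be restated as a small decision function. This is a hypothetical helper, not a library function: it ignores per-architecture support checks and simply encodes the rule "an explicit choice wins; otherwise SDPA if available on torch>=2.1.1, else eager".

```python
def default_attn_implementation(requested=None, torch_version=(2, 1, 1), sdpa_available=True):
    """Hypothetical helper restating the documented attn_implementation default."""
    allowed = ("eager", "sdpa", "flash_attention_2")
    if requested is not None:
        if requested not in allowed:
            raise ValueError(f"attn_implementation must be one of {allowed}, got {requested!r}")
        return requested  # an explicit choice always wins in this sketch
    # Documented default: SDPA when available on torch>=2.1.1, otherwise "eager".
    if sdpa_available and tuple(torch_version) >= (2, 1, 1):
        return "sdpa"
    return "eager"


print(default_attn_implementation())                        # sdpa
print(default_attn_implementation(torch_version=(2, 0, 1)))  # eager
```

With the real library the choice is passed as a keyword argument to from_config() or from_pretrained(), for example attn_implementation="sdpa"; whether a given architecture supports SDPA or FlashAttention-2 depends on the model.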
from_pretrained( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In that case, though, check whether using save_pretrained() and from_pretrained() would be a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
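The "missing keys" and "unexpected keys" reported with output_loading_info=True (and relevant when passing your own state_dict) are, in essence, two set differences between the parameter names the model expects and those found in the checkpoint. The sketch below is a hypothetical simplification: loading_info is an illustrative function, the parameter names are made up, and the library's own report may contain additional entries (for example, shape mismatches), which are not modeled here.

```python
def loading_info(expected_keys, checkpoint_keys):
    """Hypothetical sketch: compare the parameter names a model expects with a checkpoint's."""
    expected, found = set(expected_keys), set(checkpoint_keys)
    return {
        "missing_keys": sorted(expected - found),     # needed by the model, absent from the checkpoint
        "unexpected_keys": sorted(found - expected),  # present in the checkpoint, unused by the model
        "error_msgs": [],  # loading errors would be collected here; none are simulated
    }


info = loading_info(
    expected_keys=["vision.proj.weight", "text.embed.weight", "lm_head.weight"],
    checkpoint_keys=["vision.proj.weight", "text.embed.weight", "pooler.weight"],
)
print(info["missing_keys"], info["unexpected_keys"])  # ['lm_head.weight'] ['pooler.weight']
```

Missing keys are typically left randomly initialized (which is expected when a new task head is added on top of a pretrained backbone), while unexpected keys are simply ignored.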
Instantiate one of the model classes of the library (with a vision-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- blip — BlipForConditionalGeneration (BLIP model)
- blip-2 — Blip2ForConditionalGeneration (BLIP-2 model)
- chameleon — ChameleonForConditionalGeneration (Chameleon model)
- git — GitForCausalLM (GIT model)
- idefics2 — Idefics2ForConditionalGeneration (Idefics2 model)
- idefics3 — Idefics3ForConditionalGeneration (Idefics3 model)
- instructblip — InstructBlipForConditionalGeneration (InstructBLIP model)
- instructblipvideo — InstructBlipVideoForConditionalGeneration (InstructBlipVideo model)
- kosmos-2 — Kosmos2ForConditionalGeneration (KOSMOS-2 model)
- llava — LlavaForConditionalGeneration (LLaVa model)
- llava_next — LlavaNextForConditionalGeneration (LLaVA-NeXT model)
- llava_next_video — LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video model)
- llava_onevision — LlavaOnevisionForConditionalGeneration (LLaVA-Onevision model)
- mllama — MllamaForConditionalGeneration (Mllama model)
- paligemma — PaliGemmaForConditionalGeneration (PaliGemma model)
- pix2struct — Pix2StructForConditionalGeneration (Pix2Struct model)
- qwen2_vl — Qwen2VLForConditionalGeneration (Qwen2VL model)
- video_llava — VideoLlavaForConditionalGeneration (VideoLlava model)
- vipllava — VipLlavaForConditionalGeneration (VipLlava model)
- vision-encoder-decoder — VisionEncoderDecoderModel (Vision Encoder decoder model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModelForVision2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-base")
>>> # Update configuration during loading
>>> model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForVision2Seq.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForVision2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a vision-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (doing so throws an error).
from_config( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - BlipConfig configuration class: TFBlipForConditionalGeneration (BLIP model)
  - VisionEncoderDecoderConfig configuration class: TFVisionEncoderDecoderModel (Vision Encoder decoder model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a vision-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
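The interplay of cache_dir, local_files_only and force_download can be pictured with a toy resolver. This is a hypothetical sketch under simplifying assumptions: resolve_file and its download callback are illustrative names, and real file resolution in Transformers is delegated to huggingface_hub, whose cache layout (snapshots, revisions, ETags) is considerably more involved and is not reproduced here.

```python
import tempfile
from pathlib import Path


def resolve_file(cache_dir, filename, download, local_files_only=False, force_download=False):
    """Hypothetical resolver showing how cache_dir, local_files_only and force_download interact."""
    cached = Path(cache_dir) / filename
    if cached.exists() and not force_download:
        return cached  # cache hit: no network access needed
    if local_files_only:
        raise FileNotFoundError(f"{filename} is not in {cache_dir} and local_files_only=True")
    cached.parent.mkdir(parents=True, exist_ok=True)
    cached.write_bytes(download(filename))  # (re-)download and overwrite the cached copy
    return cached


with tempfile.TemporaryDirectory() as tmp:
    resolve_file(tmp, "config.json", lambda name: b"{}")  # first call "downloads"
    again = resolve_file(tmp, "config.json", lambda name: b"unused", local_files_only=True)
    print(again.read_bytes())  # b'{}'
```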
Instantiate one of the model classes of the library (with a vision-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- blip — TFBlipForConditionalGeneration (BLIP model)
- vision-encoder-decoder — TFVisionEncoderDecoderModel (Vision Encoder decoder model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForVision2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-base")
>>> # Update configuration during loading
>>> model = TFAutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForVision2Seq.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForVision2Seq
This is a generic model class that will be instantiated as one of the model classes of the library (with a vision-to-text modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (doing so throws an error).
from_config( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
  - VisionEncoderDecoderConfig configuration class: FlaxVisionEncoderDecoderModel (Vision Encoder decoder model)
- attn_implementation (str, optional) — The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
Instantiates one of the model classes of the library (with a vision-to-text modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
Instantiate one of the model classes of the library (with a vision-to-text modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it is missing, by falling back to pattern matching on pretrained_model_name_or_path:
- vision-encoder-decoder — FlaxVisionEncoderDecoderModel (Vision Encoder decoder model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForVision2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForVision2Seq.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForVision2Seq.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForVision2Seq.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )