erfanzar/Marin-8B-Instruct-eformat
A model implemented with the EasyDeL framework, built for efficient large-scale natural language processing tasks.
Overview
This model is built using EasyDeL, an open-source framework designed to enhance and streamline the training and serving process of machine learning models, with a primary focus on Jax/Flax on TPU/GPU at scale.
EasyDeL provides an efficient, highly optimized, and customizable machine learning model compatible with both GPU and TPU environments. Built with JAX, this model supports advanced features such as sharded model parallelism and customized kernels, making it suitable for distributed training and inference.
Features Provided by EasyDeL
EasyDeL Framework Features:
- Efficient Implementation: Built with JAX/Flax for high-performance computation.
- Modern Architecture: Built on Flax NNX for better integration, modularity, and performance.
- Multi-Device Support: Optimized to run on TPU, GPU, and CPU environments.
- Sharded Model Parallelism: Supports model parallelism across multiple devices for scalability (using auto_shard_model=True).
- Customizable Precision: Allows specification of dtype, param_dtype, and precision.
- Advanced Serving: Includes the vInference engine and an OpenAI-compatible API server.
- Optimized Kernels: Integrates multiple attention mechanisms (such as AttentionMechanisms.SPLASH) and platform-specific optimizations.
Installation
To use this model via EasyDeL, first install EasyDeL:
pip install easydel
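EasyDeL runs on top of JAX, so JAX itself must be installed with support for your accelerator. The exact command depends on your hardware and driver setup (check the official JAX installation guide); typical choices look like:
pip install -U "jax[cuda12]"   # NVIDIA GPU
pip install -U "jax[tpu]"      # Cloud TPU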
Usage
Loading the Pre-trained Model
To load this pre-trained model with EasyDeL:
from easydel import AutoEasyDeLModelForCausalLM, EasyDeLBaseConfigDict, AttentionMechanisms
from jax import numpy as jnp, lax
# Define max_length if needed for memory optimization
max_length = None
# Load model and parameters
# Set auto_shard_model=True to automatically distribute across devices
model = AutoEasyDeLModelForCausalLM.from_pretrained(
    "erfanzar/Marin-8B-Instruct-eformat",
    config_kwargs=EasyDeLBaseConfigDict(
        # use_scan_mlp=False,  # Set to True to potentially reduce memory usage
        attn_dtype=jnp.float16,  # Or jnp.bfloat16
        # freq_max_position_embeddings=max_length,  # Set if using RoPE and truncation is needed
        # mask_max_position_embeddings=max_length,  # Set if a max length is defined
        attn_mechanism=AttentionMechanisms.SPLASH,  # Matches the mechanism used by this model
    ),
    dtype=jnp.float16,  # Or jnp.bfloat16 - computation data type
    param_dtype=jnp.float16,  # Or jnp.bfloat16 - parameter data type
    precision=lax.Precision("fastest"),  # Options include "default", "fastest", "high", "highest"
    auto_shard_model=True,  # Automatically shard across available devices
)
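Running the Model
A minimal sketch of a single forward pass with the loaded model. It assumes this repository ships a standard Hugging Face tokenizer and that the EasyDeL module is callable with input_ids/attention_mask and returns an output carrying logits, as is conventional for Flax causal-LM modules; adjust to the API of your EasyDeL version if it differs.
from transformers import AutoTokenizer
from jax import numpy as jnp

# Load the tokenizer published alongside the weights (assumed to exist in this repo)
tokenizer = AutoTokenizer.from_pretrained("erfanzar/Marin-8B-Instruct-eformat")

# Tokenize a prompt and run one forward pass to get next-token logits
inputs = tokenizer("Explain sharded model parallelism in one sentence.", return_tensors="np")
outputs = model(
    input_ids=jnp.asarray(inputs["input_ids"]),
    attention_mask=jnp.asarray(inputs["attention_mask"]),
)
next_token_id = int(jnp.argmax(outputs.logits[0, -1]))  # greedy choice for the next token
print(tokenizer.decode([next_token_id]))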
Supported Tasks
The primary task for this model is TaskType.CAUSAL_LM. Further specific supported tasks are not explicitly listed.
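Because the task is causal language modeling, a plain autoregressive loop can be sketched on top of the forward pass above. This is illustrative only: it has no KV cache and recompiles at every step as the sequence grows, so the vInference engine mentioned earlier is the intended serving path. It reuses the model and tokenizer objects (and the callable-interface assumption) from the previous snippets.
from jax import numpy as jnp

def greedy_generate(model, tokenizer, prompt: str, max_new_tokens: int = 32) -> str:
    # Naive greedy decoding: repeatedly append the arg-max token until EOS or the budget is hit.
    enc = tokenizer(prompt, return_tensors="np")
    input_ids = jnp.asarray(enc["input_ids"])
    for _ in range(max_new_tokens):
        attention_mask = jnp.ones_like(input_ids)
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        next_id = jnp.argmax(logits[:, -1, :], axis=-1)[:, None]  # (batch, 1)
        input_ids = jnp.concatenate([input_ids, next_id], axis=-1)
        if int(next_id[0, 0]) == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)

print(greedy_generate(model, tokenizer, "What is JAX?"))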
Limitations
General Limitations:
- Hardware Dependency: Performance can vary significantly based on the hardware (TPU/GPU) used.
- JAX/Flax Setup Required: The environment must support JAX/Flax for optimal use.
- Experimental Features: Some EasyDeL features (like custom kernels) may require additional configuration.
License
EasyDeL is released under the Apache v2 license. The license for this specific model might differ; please consult the original model repository or documentation.
# Apache License 2.0 (referring to EasyDeL Framework)
# ... (Full license text usually included in the main repo) ...
Citation
If you use EasyDeL in your research or work, please cite it:
@misc{Zare_Chavoshi_2023,
  title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
  url={https://github.com/erfanzar/EasyDeL},
  author={Zare Chavoshi, Erfan},
  year={2023}
}
Please also consider citing the original paper or source for the erfanzar/Marin-8B-Instruct-eformat model architecture if applicable.