Glebs Vinarskis
commited on
Commit
·
26b1bda
1
Parent(s):
f5dc74f
Initial commit including model and configuration
Browse files- __init__.py +0 -0
- config.json +17 -0
- configuration_stacked.py +32 -0
- lang_ident.py +40 -0
- modeling_stacked.py +159 -0
- push_to_hf.py +181 -0
- test.py +16 -0
__init__.py
ADDED
File without changes
|
config.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoConfig": "configuration_stacked.ImpressoConfig"
|
4 |
+
},
|
5 |
+
"custom_pipelines": {
|
6 |
+
"lang-ident": {
|
7 |
+
"impl": "lang_ident.LangIdentPipeline",
|
8 |
+
"pt": [
|
9 |
+
"ExtendedMultitaskModelForTokenClassification"
|
10 |
+
],
|
11 |
+
"tf": []
|
12 |
+
}
|
13 |
+
},
|
14 |
+
"filename": "LID-40-3-2000000-1-4.bin",
|
15 |
+
"model_type": "floret",
|
16 |
+
"transformers_version": "4.45.2"
|
17 |
+
}
|
configuration_stacked.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
import os
|
3 |
+
|
4 |
+
|
5 |
+
class ImpressoConfig(PretrainedConfig):
|
6 |
+
model_type = "floret"
|
7 |
+
|
8 |
+
def __init__(self, filename="LID-40-3-2000000-1-4.bin", **kwargs):
|
9 |
+
super().__init__(**kwargs)
|
10 |
+
self.filename = filename
|
11 |
+
|
12 |
+
@classmethod
|
13 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
14 |
+
# Bypass JSON loading and create config directly
|
15 |
+
print(f"Loading ImpressoConfig from {pretrained_model_name_or_path}")
|
16 |
+
print(os.getcwd())
|
17 |
+
config = cls(filename="LID-40-3-2000000-1-4.bin", **kwargs)
|
18 |
+
return config
|
19 |
+
|
20 |
+
|
21 |
+
# Register the configuration with the transformers library
|
22 |
+
ImpressoConfig.register_for_auto_class()
|
23 |
+
|
24 |
+
# Register the custom pipeline
|
25 |
+
# PIPELINE_REGISTRY.register_pipeline(
|
26 |
+
# task="lang-ident",
|
27 |
+
# pipeline_class=LangIdentPipeline,
|
28 |
+
# model=AutoModelForSequenceClassification,
|
29 |
+
# tokenizer=AutoTokenizer,
|
30 |
+
# )
|
31 |
+
#
|
32 |
+
# print("Custom pipeline 'lang-ident' registered successfully.")
|
lang_ident.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Pipeline
|
2 |
+
|
3 |
+
|
4 |
+
class LangIdentPipeline(Pipeline):
|
5 |
+
|
6 |
+
def _sanitize_parameters(self, **kwargs):
|
7 |
+
preprocess_kwargs = {}
|
8 |
+
if "text" in kwargs:
|
9 |
+
preprocess_kwargs["text"] = kwargs["text"]
|
10 |
+
return preprocess_kwargs, {}, {}
|
11 |
+
|
12 |
+
def preprocess(self, text, **kwargs):
|
13 |
+
print("this is preprocessing:")
|
14 |
+
print(text)
|
15 |
+
return text
|
16 |
+
|
17 |
+
def _forward(self, text):
|
18 |
+
# Extract label and confidence
|
19 |
+
predictions, probabilities = self.model.predict([text], k=1)
|
20 |
+
|
21 |
+
label = predictions[0][0].replace("__label__", "") # Remove __label__ prefix
|
22 |
+
confidence = float(
|
23 |
+
probabilities[0][0]
|
24 |
+
) # Convert to float for JSON serialization
|
25 |
+
|
26 |
+
# Format as JSON-compatible dictionary
|
27 |
+
model_output = {"label": label, "confidence": round(confidence * 100, 2)}
|
28 |
+
|
29 |
+
print("Formatted Model Output:", model_output)
|
30 |
+
return model_output
|
31 |
+
|
32 |
+
def postprocess(self, outputs, **kwargs):
|
33 |
+
return outputs
|
34 |
+
|
35 |
+
|
36 |
+
# PIPELINE_REGISTRY.register_pipeline(
|
37 |
+
# task="language-detection",
|
38 |
+
# pipeline_class=Pipeline_One,
|
39 |
+
# default={"model": None},
|
40 |
+
# )
|
modeling_stacked.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel, AutoModel, AutoConfig, PretrainedConfig
|
2 |
+
import floret, torch
|
3 |
+
import os, shutil
|
4 |
+
from configuration_stacked import ImpressoConfig
|
5 |
+
from transformers.modeling_utils import (
|
6 |
+
get_parameter_device as original_get_parameter_device,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
# Import Hugging Face dependencies
|
13 |
+
import transformers.modeling_utils
|
14 |
+
from transformers.modeling_utils import PreTrainedModel
|
15 |
+
from transformers.modeling_utils import (
|
16 |
+
get_parameter_device as original_get_parameter_device,
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
# Custom get_parameter_device
|
21 |
+
def custom_get_parameter_device(module):
|
22 |
+
"""
|
23 |
+
Custom get_parameter_device() to handle floret models.
|
24 |
+
Returns 'cpu' for FloretModelWrapper, otherwise uses the original implementation.
|
25 |
+
"""
|
26 |
+
# Check if the model is an instance of your FloretModelWrapper
|
27 |
+
if isinstance(module, FloretModelWrapper):
|
28 |
+
print(
|
29 |
+
"Custom get_parameter_device(): Detected FloretModelWrapper. Returning 'cpu'."
|
30 |
+
)
|
31 |
+
return torch.device("cpu")
|
32 |
+
|
33 |
+
# Otherwise, fall back to Hugging Face's original implementation
|
34 |
+
return original_get_parameter_device(module)
|
35 |
+
|
36 |
+
|
37 |
+
# Custom device property
|
38 |
+
@property
|
39 |
+
def custom_device(self) -> torch.device:
|
40 |
+
"""
|
41 |
+
Custom device() method to handle floret models.
|
42 |
+
Always returns torch.device('cpu') for FloretModelWrapper.
|
43 |
+
"""
|
44 |
+
# Check if the model is an instance of your FloretModelWrapper
|
45 |
+
if isinstance(self, FloretModelWrapper):
|
46 |
+
print(
|
47 |
+
"Custom device(): Detected FloretModelWrapper. Returning torch.device('cpu')."
|
48 |
+
)
|
49 |
+
return torch.device("cpu")
|
50 |
+
|
51 |
+
# Otherwise, fall back to Hugging Face's original implementation
|
52 |
+
return torch.device("cpu") # original_device.__get__(self, type(self))
|
53 |
+
|
54 |
+
|
55 |
+
# Monkey-patch get_parameter_device and device property
|
56 |
+
transformers.modeling_utils.get_parameter_device = custom_get_parameter_device
|
57 |
+
PreTrainedModel.device = custom_device
|
58 |
+
|
59 |
+
print("Monkey-patch applied: get_parameter_device and device property")
|
60 |
+
|
61 |
+
# logger = logging.getLogger(__name__)
|
62 |
+
|
63 |
+
original_device = PreTrainedModel.device
|
64 |
+
|
65 |
+
|
66 |
+
def get_info(label_map):
|
67 |
+
num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
|
68 |
+
return num_token_labels_dict
|
69 |
+
|
70 |
+
|
71 |
+
class FloretModelWrapper:
|
72 |
+
"""
|
73 |
+
Wrapper for floret model to make it compatible with Hugging Face pipeline.
|
74 |
+
Mocks the .device attribute and passes predict() unchanged.
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(self, floret_model):
|
78 |
+
self.floret_model = floret_model
|
79 |
+
|
80 |
+
# Mocking the .device attribute to make Hugging Face happy
|
81 |
+
self.device = torch.device("cpu") # floret is always on CPU
|
82 |
+
|
83 |
+
def predict(self, text, k=1):
|
84 |
+
"""
|
85 |
+
Pass-through for floret's predict() method.
|
86 |
+
"""
|
87 |
+
return self.floret_model.predict(text, k=k)
|
88 |
+
|
89 |
+
|
90 |
+
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
91 |
+
|
92 |
+
config_class = ImpressoConfig
|
93 |
+
# Monkey-patch get_parameter_device
|
94 |
+
|
95 |
+
def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
|
96 |
+
super().__init__(config)
|
97 |
+
self.config = config
|
98 |
+
print("Doest is it even pass through here?")
|
99 |
+
print(
|
100 |
+
f"The config in ExtendedMultitaskModelForTokenClassification is: {self.config}"
|
101 |
+
)
|
102 |
+
# self.model = floret.load_model(self.config.filename)
|
103 |
+
|
104 |
+
def predict(self, text, k=1):
|
105 |
+
predictions = self.model.predict(text, k)
|
106 |
+
return predictions
|
107 |
+
|
108 |
+
@classmethod
|
109 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
110 |
+
print("Calling from_pretrained...")
|
111 |
+
|
112 |
+
# Initialize model with config
|
113 |
+
model = cls(ImpressoConfig())
|
114 |
+
|
115 |
+
# Load model using floret
|
116 |
+
print(f"---Loading model from: {model.config.filename}")
|
117 |
+
floret_model = floret.load_model(model.config.filename)
|
118 |
+
|
119 |
+
# Wrap the model to fake .device attribute
|
120 |
+
model.model = FloretModelWrapper(floret_model)
|
121 |
+
|
122 |
+
print(model.model, "device:", model.model.device)
|
123 |
+
|
124 |
+
print(f"Model loaded and wrapped from: {model.config.filename}")
|
125 |
+
|
126 |
+
return model
|
127 |
+
|
128 |
+
def save_pretrained(self, save_directory, *args, **kwargs):
|
129 |
+
# Ignore Hugging Face-specific arguments
|
130 |
+
max_shard_size = kwargs.pop("max_shard_size", None)
|
131 |
+
safe_serialization = kwargs.pop("safe_serialization", False)
|
132 |
+
|
133 |
+
# Ensure directory exists
|
134 |
+
os.makedirs(save_directory, exist_ok=True)
|
135 |
+
|
136 |
+
# Save the model file
|
137 |
+
model_file = os.path.join(save_directory, "LID-40-3-2000000-1-4.bin")
|
138 |
+
shutil.copy(self.config.filename, model_file)
|
139 |
+
|
140 |
+
# Save the config file
|
141 |
+
config_file = os.path.join(save_directory, "config.json")
|
142 |
+
self.config.save_pretrained(save_directory)
|
143 |
+
|
144 |
+
print(f"Model saved to: {save_directory}")
|
145 |
+
|
146 |
+
def get_parameter_device(module):
|
147 |
+
"""
|
148 |
+
Custom get_parameter_device() to handle floret models.
|
149 |
+
Returns 'cpu' for floret models, and falls back to the original method otherwise.
|
150 |
+
"""
|
151 |
+
# Check if the model is an instance of your FloretModelWrapper
|
152 |
+
if isinstance(module, FloretModelWrapper):
|
153 |
+
print(
|
154 |
+
"Custom get_parameter_device(): Detected FloretModelWrapper. Returning 'cpu'."
|
155 |
+
)
|
156 |
+
return "cpu"
|
157 |
+
|
158 |
+
# Otherwise, fall back to Hugging Face's original implementation
|
159 |
+
return original_get_parameter_device(module)
|
push_to_hf.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import argparse
|
4 |
+
from transformers import (
|
5 |
+
AutoTokenizer,
|
6 |
+
AutoConfig,
|
7 |
+
AutoModelForSequenceClassification,
|
8 |
+
)
|
9 |
+
from huggingface_hub import HfApi, Repository
|
10 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
11 |
+
|
12 |
+
# import json
|
13 |
+
from configuration_stacked import ImpressoConfig
|
14 |
+
from modeling_stacked import ExtendedMultitaskModelForTokenClassification
|
15 |
+
import subprocess
|
16 |
+
from lang_ident import LangIdentPipeline
|
17 |
+
|
18 |
+
|
19 |
+
def get_latest_checkpoint(checkpoint_dir):
|
20 |
+
checkpoints = [
|
21 |
+
d
|
22 |
+
for d in os.listdir(checkpoint_dir)
|
23 |
+
if os.path.isdir(os.path.join(checkpoint_dir, d))
|
24 |
+
and d.startswith("checkpoint-")
|
25 |
+
]
|
26 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]), reverse=True)
|
27 |
+
return os.path.join(checkpoint_dir, checkpoints[0])
|
28 |
+
|
29 |
+
|
30 |
+
def get_info(label_map):
|
31 |
+
num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
|
32 |
+
return num_token_labels_dict
|
33 |
+
|
34 |
+
|
35 |
+
def push_model_to_hub(checkpoint_dir, repo_name):
|
36 |
+
# checkpoint_path = get_latest_checkpoint(checkpoint_dir)
|
37 |
+
checkpoint_path = checkpoint_dir
|
38 |
+
config = ImpressoConfig.from_pretrained(checkpoint_path)
|
39 |
+
print(config)
|
40 |
+
|
41 |
+
config.pretrained_config = ImpressoConfig.from_pretrained(config.filename)
|
42 |
+
config.save_pretrained("floret")
|
43 |
+
config = ImpressoConfig.from_pretrained("floret")
|
44 |
+
PIPELINE_REGISTRY.register_pipeline(
|
45 |
+
"lang-ident",
|
46 |
+
pipeline_class=LangIdentPipeline,
|
47 |
+
pt_model=ExtendedMultitaskModelForTokenClassification,
|
48 |
+
)
|
49 |
+
|
50 |
+
# PIPELINE_REGISTRY.register_pipeline(
|
51 |
+
# "pair-classification",
|
52 |
+
# pipeline_class=PairClassificationPipeline,
|
53 |
+
# pt_model=AutoModelForSequenceClassification,
|
54 |
+
# tf_model=TFAutoModelForSequenceClassification,
|
55 |
+
# )
|
56 |
+
|
57 |
+
config.custom_pipelines = {
|
58 |
+
"lang-ident": {
|
59 |
+
"impl": "lang_ident.LangIdentPipeline",
|
60 |
+
"pt": ["AutoModelForSequenceClassification"],
|
61 |
+
"tf": [],
|
62 |
+
}
|
63 |
+
}
|
64 |
+
model = ExtendedMultitaskModelForTokenClassification.from_pretrained(
|
65 |
+
checkpoint_path, config=config
|
66 |
+
)
|
67 |
+
|
68 |
+
local_repo_path = "lang-detect"
|
69 |
+
repo_url = HfApi().create_repo(repo_id=repo_name, exist_ok=True)
|
70 |
+
repo = Repository(local_dir=local_repo_path, clone_from=repo_url)
|
71 |
+
|
72 |
+
try:
|
73 |
+
# Try to pull the latest changes from the remote repository using subprocess
|
74 |
+
subprocess.run(["git", "pull"], check=True, cwd=local_repo_path)
|
75 |
+
except subprocess.CalledProcessError as e:
|
76 |
+
# If fast-forward is not possible, reset the local branch to match the remote branch
|
77 |
+
subprocess.run(
|
78 |
+
["git", "reset", "--hard", "origin/main"],
|
79 |
+
check=True,
|
80 |
+
cwd=local_repo_path,
|
81 |
+
)
|
82 |
+
|
83 |
+
# Copy all Python files to the local repository directory
|
84 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
85 |
+
for filename in os.listdir(current_dir):
|
86 |
+
if filename.endswith(".py") or filename.endswith(".json"):
|
87 |
+
shutil.copy(
|
88 |
+
os.path.join(current_dir, filename),
|
89 |
+
os.path.join(local_repo_path, filename),
|
90 |
+
)
|
91 |
+
|
92 |
+
ImpressoConfig.register_for_auto_class()
|
93 |
+
|
94 |
+
AutoConfig.register("floret", ImpressoConfig)
|
95 |
+
AutoModelForSequenceClassification.register(
|
96 |
+
ImpressoConfig, ExtendedMultitaskModelForTokenClassification
|
97 |
+
)
|
98 |
+
ExtendedMultitaskModelForTokenClassification.register_for_auto_class(
|
99 |
+
"AutoModelForSequenceClassification"
|
100 |
+
)
|
101 |
+
# model.save_pretrained(local_repo_path)
|
102 |
+
|
103 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
104 |
+
from transformers import pipeline
|
105 |
+
|
106 |
+
# Define the model name to be used for token classification, we use the Impresso NER
|
107 |
+
# that can be found at "https://huggingface.co/impresso-project/ner-stacked-bert-multilingual"
|
108 |
+
MODEL_NAME = "Maslionok/lang-detect"
|
109 |
+
#
|
110 |
+
|
111 |
+
# # Add, commit and push the changes to the repository
|
112 |
+
subprocess.run(["git", "add", "."], check=True, cwd=local_repo_path)
|
113 |
+
subprocess.run(
|
114 |
+
["git", "commit", "-m", "Initial commit including model and configuration"],
|
115 |
+
check=True,
|
116 |
+
cwd=local_repo_path,
|
117 |
+
)
|
118 |
+
subprocess.run(["git", "push"], check=True, cwd=local_repo_path)
|
119 |
+
#
|
120 |
+
# Push the model to the hub (this includes the README template)
|
121 |
+
model.push_to_hub(repo_name)
|
122 |
+
|
123 |
+
lang_pipeline = pipeline(
|
124 |
+
"lang-ident", model=MODEL_NAME, trust_remote_code=True, device="cpu"
|
125 |
+
)
|
126 |
+
lang_pipeline.push_to_hub(MODEL_NAME)
|
127 |
+
sentence = "En l'an 1348, au plus fort des ravages de la peste noire à travers l'Europe, le Royaume de France se trouvait à la fois au bord du désespoir et face à une opportunité. À la cour du roi Philippe VI, les murs du Louvre étaient animés par les rapports sombres venus de Paris et des villes environnantes. La peste ne montrait aucun signe de répit, et le chancelier Guillaume de Nogaret, le conseiller le plus fidèle du roi, portait le lourd fardeau de gérer la survie du royaume."
|
128 |
+
#
|
129 |
+
print(lang_pipeline(sentence))
|
130 |
+
# lang_pipeline.push_to_hub(MODEL_NAME)
|
131 |
+
print(f"Model and repo pushed to: {repo_url}")
|
132 |
+
|
133 |
+
|
134 |
+
if __name__ == "__main__":
|
135 |
+
parser = argparse.ArgumentParser(description="Push NER model to Hugging Face Hub")
|
136 |
+
parser.add_argument(
|
137 |
+
"--model_type",
|
138 |
+
type=str,
|
139 |
+
required=True,
|
140 |
+
help="Type of the model (e.g., langident)",
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--language",
|
144 |
+
type=str,
|
145 |
+
required=True,
|
146 |
+
help="Language of the model (e.g., multilingual)",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--checkpoint_dir",
|
150 |
+
type=str,
|
151 |
+
required=True,
|
152 |
+
default="LID-40-3-2000000-1-4.bin",
|
153 |
+
help="Directory containing checkpoint folders",
|
154 |
+
)
|
155 |
+
args = parser.parse_args()
|
156 |
+
repo_name = f"Maslionok/lang-detect"
|
157 |
+
push_model_to_hub(args.checkpoint_dir, repo_name)
|
158 |
+
|
159 |
+
# PIPELINE_REGISTRY.register_pipeline(
|
160 |
+
# "generic-ner",
|
161 |
+
# pipeline_class=MultitaskTokenClassificationPipeline,
|
162 |
+
# pt_model=ExtendedMultitaskModelForTokenClassification,
|
163 |
+
# )
|
164 |
+
# model.config.custom_pipelines = {
|
165 |
+
# "generic-ner": {
|
166 |
+
# "impl": "generic_ner.MultitaskTokenClassificationPipeline",
|
167 |
+
# "pt": ["ExtendedMultitaskModelForTokenClassification"],
|
168 |
+
# "tf": [],
|
169 |
+
# }
|
170 |
+
# }
|
171 |
+
# classifier = pipeline(
|
172 |
+
# "generic-ner", model=model, tokenizer=tokenizer, label_map=label_map
|
173 |
+
# )
|
174 |
+
# from pprint import pprint
|
175 |
+
#
|
176 |
+
# pprint(
|
177 |
+
# classifier(
|
178 |
+
# "1. Le public est averti que Charlotte née Bourgoin, femme-de Joseph Digiez, et Maurice Bourgoin, enfant mineur représenté par le sieur Jaques Charles Gicot son curateur, ont été admis par arrêt du Conseil d'Etat du 5 décembre 1797, à solliciter une renonciation générale et absolue aux biens et aux dettes présentes et futures de Jean-Baptiste Bourgoin leur père."
|
179 |
+
# )
|
180 |
+
# )
|
181 |
+
# repo.push_to_hub(commit_message="Initial commit of the trained NER model with code")
|
test.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Import necessary Python modules from the Transformers library
|
2 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
3 |
+
from transformers import pipeline
|
4 |
+
# Define the model name to be used for token classification, we use the Impresso NER
|
5 |
+
# that can be found at "https://huggingface.co/impresso-project/ner-stacked-bert-multilingual"
|
6 |
+
MODEL_NAME = "emanuelaboros/lang-detect"
|
7 |
+
|
8 |
+
lang_pipeline = pipeline("lang-ident", model=MODEL_NAME,
|
9 |
+
trust_remote_code=True,
|
10 |
+
device='cpu')
|
11 |
+
|
12 |
+
sentence = "En l'an 1348, au plus fort des ravages de la peste noire à travers l'Europe, le Royaume de France se trouvait à la fois au bord du désespoir et face à une opportunité. À la cour du roi Philippe VI, les murs du Louvre étaient animés par les rapports sombres venus de Paris et des villes environnantes. La peste ne montrait aucun signe de répit, et le chancelier Guillaume de Nogaret, le conseiller le plus fidèle du roi, portait le lourd fardeau de gérer la survie du royaume."
|
13 |
+
|
14 |
+
entities = lang_pipeline(sentence)
|
15 |
+
print(entities)
|
16 |
+
|