AmelieSchreiber
commited on
Commit
Β·
d7c5a01
1
Parent(s):
9630065
Delete esmbind-validation-struct.ipynb
Browse files
esmbind-validation-struct.ipynb
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"!pip install transformers -q\n!pip install accelerate -q\n!pip install peft -q\n!pip install datasets -q","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2023-09-24T03:32:05.937584Z","iopub.execute_input":"2023-09-24T03:32:05.938290Z","iopub.status.idle":"2023-09-24T03:32:56.415776Z","shell.execute_reply.started":"2023-09-24T03:32:05.938252Z","shell.execute_reply":"2023-09-24T03:32:56.414524Z"},"trusted":true},"execution_count":1,"outputs":[]},{"cell_type":"markdown","source":"# Testing ESMBind (ESMB) for Protein Binding Residue Prediction\n\nThis notebook is meant to test out ESM-2 LoRA models on the datasets found [here](https://github.com/hamzagamouh/pt-lm-gnn/tree/main/datasets/yu_merged) for the paper [Hybrid protein-ligand binding residue prediction with protein\nlanguage models: Does the structure matter?](https://www.biorxiv.org/content/10.1101/2023.08.11.553028v1). The models referenced in the paper are GCN, GAT, and ensemble structural models trained on PDB sequences to predict binding residues. They are the best performing models that could be found as of 17/09/23. You will need to download the datasets you want to test out from the github above and provide the file path in the code below. For your convenience, some of the performance metrics are provided below. Notice that the train/test metrics from the original data the models were trains on do not show any of the usual signs of overfitting:\n\n```python\nTrain metrics: \n {'eval_loss': 0.11367090046405792, \n 'eval_accuracy': 0.961073623713503, \n 'eval_precision': 0.3506606081587021, \n 'eval_recall': 0.9097597679932995, \n 'eval_f1': 0.5062071663690367, \n 'eval_auc': 0.9359920115129883, \n 'eval_mcc': 0.5513080553639849}\nTest metrics: \n {'eval_loss': 0.11328430473804474, \n 'eval_accuracy': 0.9604888971537066, \n 'eval_precision': 0.34630886072474065, \n 'eval_recall': 0.9135862937475725, \n 'eval_f1': 0.5022370749476722, \n 'eval_auc': 0.9375606817360377, \n 'eval_mcc': 0.5489185177475369}\n```\n\nYet the model does not seem to generalize well to the PDB datasets. ","metadata":{}},{"cell_type":"code","source":"import pandas as pd\n\n# Load the dataset\ndata_df = pd.read_csv(\"/kaggle/input/binding-sites-struct-2/ZN_Training.txt\", delimiter=';')\n\n# Display the first few rows of the dataframe to understand its structure\ndata_df.head()","metadata":{"execution":{"iopub.status.busy":"2023-09-24T03:53:04.393908Z","iopub.execute_input":"2023-09-24T03:53:04.394318Z","iopub.status.idle":"2023-09-24T03:53:04.427276Z","shell.execute_reply.started":"2023-09-24T03:53:04.394286Z","shell.execute_reply":"2023-09-24T03:53:04.426211Z"},"trusted":true},"execution_count":22,"outputs":[{"execution_count":22,"output_type":"execute_result","data":{"text/plain":" pdb_id chain_id binding_residues \\\n0 2E7Y A D52 H53 D190 H236 H48 H50 H134 D190 \n1 2YV5 A C248 H250 C256 \n2 2I9W A C8 C10 C28 C29 \n3 3IR9 A C79 C82 C102 C105 \n4 2OKL A Q66 C111 H154 H158 \n\n sequence \n0 MNIIGFSKALFSTWIYYSPERILFDAGEGVSTTLGSKVYAFKYVFL... \n1 MGKKELKRGLVVDREAQMIGVYLFEDGKTYRGIPRGKVLKKTKIYA... \n2 MLFSIQTCPCQINPALNAVSTPLLYQDCCQPYHDGLYNQAIRADTA... \n3 AYTDESGLSELVNAAGEKLQDLELMGQKNAVRDFFKELIADSGKVA... \n4 HMLTMKDVIREGDPILRNVAEEVSLPASEEDTTTLKEMIEFVINSQ... ","text/html":"<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>pdb_id</th>\n <th>chain_id</th>\n <th>binding_residues</th>\n <th>sequence</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>2E7Y</td>\n <td>A</td>\n <td>D52 H53 D190 H236 H48 H50 H134 D190</td>\n <td>MNIIGFSKALFSTWIYYSPERILFDAGEGVSTTLGSKVYAFKYVFL...</td>\n </tr>\n <tr>\n <th>1</th>\n <td>2YV5</td>\n <td>A</td>\n <td>C248 H250 C256</td>\n <td>MGKKELKRGLVVDREAQMIGVYLFEDGKTYRGIPRGKVLKKTKIYA...</td>\n </tr>\n <tr>\n <th>2</th>\n <td>2I9W</td>\n <td>A</td>\n <td>C8 C10 C28 C29</td>\n <td>MLFSIQTCPCQINPALNAVSTPLLYQDCCQPYHDGLYNQAIRADTA...</td>\n </tr>\n <tr>\n <th>3</th>\n <td>3IR9</td>\n <td>A</td>\n <td>C79 C82 C102 C105</td>\n <td>AYTDESGLSELVNAAGEKLQDLELMGQKNAVRDFFKELIADSGKVA...</td>\n </tr>\n <tr>\n <th>4</th>\n <td>2OKL</td>\n <td>A</td>\n <td>Q66 C111 H154 H158</td>\n <td>HMLTMKDVIREGDPILRNVAEEVSLPASEEDTTTLKEMIEFVINSQ...</td>\n </tr>\n </tbody>\n</table>\n</div>"},"metadata":{}}]},{"cell_type":"code","source":"# Define a function to convert binding residues to binary labels\ndef binding_residues_to_labels(row):\n sequence = row['sequence']\n binding_residues = row['binding_residues']\n\n # Initialize a list with zeros\n labels = [0] * len(sequence)\n\n # If binding_residues is not NaN, mark the binding residues in the labels list with 1\n if isinstance(binding_residues, str):\n # Get the indices of the binding residues\n binding_residues_indices = [int(residue[1:]) - 1 for residue in binding_residues.split()]\n\n # Mark the binding residues in the labels list with 1\n for idx in binding_residues_indices:\n if idx < len(labels):\n labels[idx] = 1\n\n return labels\n\n# Apply the function to each row in the DataFrame to get the binary labels\ndata_df['binding_labels'] = data_df.apply(binding_residues_to_labels, axis=1)\n\n# Display the first few rows of the DataFrame\ndata_df.head()\n\n","metadata":{"execution":{"iopub.status.busy":"2023-09-24T03:53:04.742061Z","iopub.execute_input":"2023-09-24T03:53:04.742701Z","iopub.status.idle":"2023-09-24T03:53:04.787931Z","shell.execute_reply.started":"2023-09-24T03:53:04.742661Z","shell.execute_reply":"2023-09-24T03:53:04.786862Z"},"trusted":true},"execution_count":23,"outputs":[{"execution_count":23,"output_type":"execute_result","data":{"text/plain":" pdb_id chain_id binding_residues \\\n0 2E7Y A D52 H53 D190 H236 H48 H50 H134 D190 \n1 2YV5 A C248 H250 C256 \n2 2I9W A C8 C10 C28 C29 \n3 3IR9 A C79 C82 C102 C105 \n4 2OKL A Q66 C111 H154 H158 \n\n sequence \\\n0 MNIIGFSKALFSTWIYYSPERILFDAGEGVSTTLGSKVYAFKYVFL... \n1 MGKKELKRGLVVDREAQMIGVYLFEDGKTYRGIPRGKVLKKTKIYA... \n2 MLFSIQTCPCQINPALNAVSTPLLYQDCCQPYHDGLYNQAIRADTA... \n3 AYTDESGLSELVNAAGEKLQDLELMGQKNAVRDFFKELIADSGKVA... \n4 HMLTMKDVIREGDPILRNVAEEVSLPASEEDTTTLKEMIEFVINSQ... \n\n binding_labels \n0 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n1 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2 [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, ... \n3 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n4 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... ","text/html":"<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>pdb_id</th>\n <th>chain_id</th>\n <th>binding_residues</th>\n <th>sequence</th>\n <th>binding_labels</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>2E7Y</td>\n <td>A</td>\n <td>D52 H53 D190 H236 H48 H50 H134 D190</td>\n <td>MNIIGFSKALFSTWIYYSPERILFDAGEGVSTTLGSKVYAFKYVFL...</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n </tr>\n <tr>\n <th>1</th>\n <td>2YV5</td>\n <td>A</td>\n <td>C248 H250 C256</td>\n <td>MGKKELKRGLVVDREAQMIGVYLFEDGKTYRGIPRGKVLKKTKIYA...</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n </tr>\n <tr>\n <th>2</th>\n <td>2I9W</td>\n <td>A</td>\n <td>C8 C10 C28 C29</td>\n <td>MLFSIQTCPCQINPALNAVSTPLLYQDCCQPYHDGLYNQAIRADTA...</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, ...</td>\n </tr>\n <tr>\n <th>3</th>\n <td>3IR9</td>\n <td>A</td>\n <td>C79 C82 C102 C105</td>\n <td>AYTDESGLSELVNAAGEKLQDLELMGQKNAVRDFFKELIADSGKVA...</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n </tr>\n <tr>\n <th>4</th>\n <td>2OKL</td>\n <td>A</td>\n <td>Q66 C111 H154 H158</td>\n <td>HMLTMKDVIREGDPILRNVAEEVSLPASEEDTTTLKEMIEFVINSQ...</td>\n <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...</td>\n </tr>\n </tbody>\n</table>\n</div>"},"metadata":{}}]},{"cell_type":"code","source":"# Define the maximum chunk size\nMAX_CHUNK_SIZE = 900\n\n# Function to segment sequences and labels into chunks of size <= 1022\ndef segment_into_chunks(row):\n sequence = row['sequence']\n labels = row['binding_labels']\n\n # Segment the sequence and labels into chunks of size <= 1022\n sequence_chunks = [sequence[i:i+MAX_CHUNK_SIZE] for i in range(0, len(sequence), MAX_CHUNK_SIZE)]\n label_chunks = [labels[i:i+MAX_CHUNK_SIZE] for i in range(0, len(labels), MAX_CHUNK_SIZE)]\n\n return sequence_chunks, label_chunks\n\n# Apply the function to each row in the DataFrame to get the segmented sequences and labels\ndata_df['sequence_chunks'] = None\ndata_df['label_chunks'] = None\nfor idx, row in data_df.iterrows():\n data_df.at[idx, 'sequence_chunks'], data_df.at[idx, 'label_chunks'] = segment_into_chunks(row)\n\n# Display the first few rows of the DataFrame\ndata_df[['pdb_id', 'chain_id', 'sequence_chunks', 'label_chunks']].head()\n\n","metadata":{"execution":{"iopub.status.busy":"2023-09-24T03:53:05.187750Z","iopub.execute_input":"2023-09-24T03:53:05.188428Z","iopub.status.idle":"2023-09-24T03:53:05.354284Z","shell.execute_reply.started":"2023-09-24T03:53:05.188392Z","shell.execute_reply":"2023-09-24T03:53:05.353228Z"},"trusted":true},"execution_count":24,"outputs":[{"execution_count":24,"output_type":"execute_result","data":{"text/plain":" pdb_id chain_id sequence_chunks \\\n0 2E7Y A [MNIIGFSKALFSTWIYYSPERILFDAGEGVSTTLGSKVYAFKYVF... \n1 2YV5 A [MGKKELKRGLVVDREAQMIGVYLFEDGKTYRGIPRGKVLKKTKIY... \n2 2I9W A [MLFSIQTCPCQINPALNAVSTPLLYQDCCQPYHDGLYNQAIRADT... \n3 3IR9 A [AYTDESGLSELVNAAGEKLQDLELMGQKNAVRDFFKELIADSGKV... \n4 2OKL A [HMLTMKDVIREGDPILRNVAEEVSLPASEEDTTTLKEMIEFVINS... \n\n label_chunks \n0 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n1 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n2 [[0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,... \n3 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n4 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... ","text/html":"<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>pdb_id</th>\n <th>chain_id</th>\n <th>sequence_chunks</th>\n <th>label_chunks</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>2E7Y</td>\n <td>A</td>\n <td>[MNIIGFSKALFSTWIYYSPERILFDAGEGVSTTLGSKVYAFKYVF...</td>\n <td>[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td>\n </tr>\n <tr>\n <th>1</th>\n <td>2YV5</td>\n <td>A</td>\n <td>[MGKKELKRGLVVDREAQMIGVYLFEDGKTYRGIPRGKVLKKTKIY...</td>\n <td>[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td>\n </tr>\n <tr>\n <th>2</th>\n <td>2I9W</td>\n <td>A</td>\n <td>[MLFSIQTCPCQINPALNAVSTPLLYQDCCQPYHDGLYNQAIRADT...</td>\n <td>[[0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,...</td>\n </tr>\n <tr>\n <th>3</th>\n <td>3IR9</td>\n <td>A</td>\n <td>[AYTDESGLSELVNAAGEKLQDLELMGQKNAVRDFFKELIADSGKV...</td>\n <td>[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td>\n </tr>\n <tr>\n <th>4</th>\n <td>2OKL</td>\n <td>A</td>\n <td>[HMLTMKDVIREGDPILRNVAEEVSLPASEEDTTTLKEMIEFVINS...</td>\n <td>[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td>\n </tr>\n </tbody>\n</table>\n</div>"},"metadata":{}}]},{"cell_type":"code","source":"from transformers import AutoModelForTokenClassification, AutoTokenizer\nfrom peft import PeftModel\nimport torch\n\ndef get_predictions(protein_sequence):\n # Path to the saved LoRA model\n model_path = \"AmelieSchreiber/esm2_t12_35M_lora_binding_sites_770K_v1\"\n # ESM2 base model\n base_model_path = \"facebook/esm2_t12_35M_UR50D\"\n\n # Load the model\n base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)\n loaded_model = PeftModel.from_pretrained(base_model, model_path)\n\n # Ensure the model is in evaluation mode\n loaded_model.eval()\n\n # Load the tokenizer\n loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)\n\n # Tokenize the sequence\n inputs = loaded_tokenizer(protein_sequence, return_tensors=\"pt\", truncation=True, max_length=1024, padding='max_length')\n\n # Run the model\n with torch.no_grad():\n logits = loaded_model(**inputs).logits\n\n # Get predictions\n tokens = loaded_tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"][0]) # Convert input ids back to tokens\n predictions = torch.argmax(logits, dim=2)[0].numpy()\n\n # Define labels\n id2label = {\n 0: \"No binding site\",\n 1: \"Binding site\"\n }\n\n # Convert predictions to binary labels (1 for binding site, 0 otherwise)\n special_tokens = ['<cls>', '<pad>', '<eos>', '<unk>', '.', '-', '<null_1>', '<mask>']\n binary_predictions = [1 if id2label[pred] == \"Binding site\" else 0 for token, pred in zip(tokens, predictions) if token not in special_tokens]\n\n return binary_predictions\n\n# Use the function to get predictions for a test sequence\ntest_sequence = \"MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT\"\nprint(get_predictions(test_sequence))\n\n","metadata":{"execution":{"iopub.status.busy":"2023-09-24T03:53:05.613511Z","iopub.execute_input":"2023-09-24T03:53:05.614181Z","iopub.status.idle":"2023-09-24T03:53:10.741727Z","shell.execute_reply.started":"2023-09-24T03:53:05.614147Z","shell.execute_reply":"2023-09-24T03:53:10.740750Z"},"trusted":true},"execution_count":25,"outputs":[{"name":"stderr","text":"Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n","output_type":"stream"},{"name":"stdout","text":"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n","output_type":"stream"}]},{"cell_type":"code","source":"from transformers import AutoModelForTokenClassification, AutoTokenizer\nfrom peft import PeftModel\nimport torch\nfrom tqdm import tqdm\nfrom tqdm.notebook import tqdm as tqdm_notebook # for notebook-compatible progress bars\nfrom sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, matthews_corrcoef\n\n# Check if a GPU is available and if not, use a CPU\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\ndef get_predictions(protein_sequence, loaded_model, loaded_tokenizer):\n # Tokenize the sequence\n inputs = loaded_tokenizer(protein_sequence, return_tensors=\"pt\", truncation=True, max_length=1000, padding='max_length')\n\n # Move the inputs to the GPU if available\n inputs = {name: tensor.to(device) for name, tensor in inputs.items()}\n\n # Run the model\n with torch.no_grad():\n logits = loaded_model(**inputs).logits\n\n # Get predictions\n tokens = loaded_tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"][0].cpu()) # Convert input ids back to tokens\n predictions = torch.argmax(logits, dim=2)[0].cpu().numpy() # Move logits to CPU before converting to numpy\n\n # Define labels\n id2label = {\n 0: \"No binding site\",\n 1: \"Binding site\"\n }\n\n # Convert predictions to binary labels (1 for binding site, 0 otherwise)\n special_tokens = ['<cls>', '<pad>', '<eos>', '<unk>', '.', '-', '<null_1>', '<mask>']\n binary_predictions = [1 if id2label[pred] == \"Binding site\" else 0 for token, pred in zip(tokens, predictions) if token not in special_tokens]\n\n return binary_predictions\n\n# Load the model and tokenizer\nbase_model_path = \"facebook/esm2_t12_35M_UR50D\"\nmodel_path = \"AmelieSchreiber/esm2_t12_35M_lora_binding_sites_770K_v1\"\nloaded_model = PeftModel.from_pretrained(AutoModelForTokenClassification.from_pretrained(base_model_path), model_path)\nloaded_model.eval()\nloaded_model.to(device) # Move the model to the GPU\nloaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)\n\n# Set up tqdm for pandas\ntqdm.pandas(desc=\"Processing rows\")\n\ndef get_chunk_predictions(row):\n global loaded_model, loaded_tokenizer\n sequence_chunks = row['sequence_chunks']\n predictions = [get_predictions(chunk, loaded_model, loaded_tokenizer) for chunk in sequence_chunks]\n return predictions\n\n# Apply the function with a progress bar using tqdm.pandas\ndata_df['predictions_chunks'] = data_df.progress_apply(get_chunk_predictions, axis=1)\n\n# Flatten the lists of labels and predictions to calculate metrics\ntrue_labels_flat = [label for sublist in data_df['label_chunks'].tolist() for subsublist in sublist for label in subsublist]\npredictions_flat = [label for sublist in data_df['predictions_chunks'].tolist() for subsublist in sublist for label in subsublist]\n\n# Calculate the metrics\naccuracy = accuracy_score(true_labels_flat, predictions_flat)\nprecision = precision_score(true_labels_flat, predictions_flat)\nrecall = recall_score(true_labels_flat, predictions_flat)\nf1 = f1_score(true_labels_flat, predictions_flat)\nauc = roc_auc_score(true_labels_flat, predictions_flat)\nmcc = matthews_corrcoef(true_labels_flat, predictions_flat)\n\n# Print the metrics\nprint(f'Accuracy: {accuracy:.4f}')\nprint(f'Precision: {precision:.4f}')\nprint(f'Recall: {recall:.4f}')\nprint(f'F1 Score: {f1:.4f}')\nprint(f'AUC: {auc:.4f}')\nprint(f'MCC: {mcc:.4f}')\n","metadata":{"execution":{"iopub.status.busy":"2023-09-24T03:53:10.743808Z","iopub.execute_input":"2023-09-24T03:53:10.744192Z","iopub.status.idle":"2023-09-24T03:54:07.450439Z","shell.execute_reply.started":"2023-09-24T03:53:10.744155Z","shell.execute_reply":"2023-09-24T03:54:07.449361Z"},"trusted":true},"execution_count":26,"outputs":[{"name":"stderr","text":"Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\nProcessing rows: 100%|ββββββββββ| 1152/1152 [00:52<00:00, 21.80it/s]\n","output_type":"stream"},{"name":"stdout","text":"Accuracy: 0.9171\nPrecision: 0.0381\nRecall: 0.1915\nF1 Score: 0.0636\nAUC: 0.5597\nMCC: 0.0550\n","output_type":"stream"}]},{"cell_type":"markdown","source":"### ADP_Training\n```python\nAccuracy: 0.9212\nPrecision: 0.2112\nRecall: 0.4052\nF1 Score: 0.2777\nAUC: 0.6732\nMCC: 0.2547\n```\n\n### AMP_Training\n```python\nAccuracy: 0.9054\nPrecision: 0.1444\nRecall: 0.3481\nF1 Score: 0.2041\nAUC: 0.6368\nMCC: 0.1809\n```\n\n### ATP_Training\n```python\nAccuracy: 0.9237\nPrecision: 0.2270\nRecall: 0.3754\nF1 Score: 0.2829\nAUC: 0.6610\nMCC: 0.2539\n```\n\n### CA_Training\n```python\nAccuracy: 0.9040\nPrecision: 0.0367\nRecall: 0.1866\nF1 Score: 0.0613\nAUC: 0.5514\nMCC: 0.0473\n```\n\n### DNA_Training\n```python\nAccuracy: 0.8764\nPrecision: 0.1943\nRecall: 0.1549\nF1 Score: 0.1724\nAUC: 0.5484\nMCC: 0.1073\n```\n\n### FE_Training\n```python\nAccuracy: 0.9286\nPrecision: 0.0584\nRecall: 0.2298\nF1 Score: 0.0932\nAUC: 0.5849\nMCC: 0.0877\n```\n\n### GDP_Training\n```python\nAccuracy: 0.9154\nPrecision: 0.2475\nRecall: 0.5395\nF1 Score: 0.3393\nAUC: 0.7353\nMCC: 0.3270\n```\n\n### GTP_Training\n```python\nAccuracy: 0.9220\nPrecision: 0.2032\nRecall: 0.4443\nF1 Score: 0.2789\nAUC: 0.6915\nMCC: 0.2646\n```\n\n### HEME_Training\n```python\nAccuracy: 0.8807\nPrecision: 0.1986\nRecall: 0.1564\nF1 Score: 0.1750\nAUC: 0.5504\nMCC: 0.1126\n```\n\n### MG_Training\n\n```python\nAccuracy: 0.9171\nPrecision: 0.0417\nRecall: 0.3013\nF1 Score: 0.0733\nAUC: 0.6126\nMCC: 0.0868\n```\n\n### MN_Training\n\n```python\nAccuracy: 0.9190\nPrecision: 0.0579\nRecall: 0.3382\nF1 Score: 0.0989\nAUC: 0.6325\nMCC: 0.1133\n```\n\n### ZN_Training\n\n```python\nAccuracy: 0.9171\nPrecision: 0.0381\nRecall: 0.1915\nF1 Score: 0.0636\nAUC: 0.5597\nMCC: 0.0550\n```\n","metadata":{}},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]}
|
|
|
|