{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from huggingface_hub import hf_hub_download\n",
    "from mlip_arena.models import MLIP, MLIPCalculator, ModuleMLIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "fpath = hf_hub_download(\n",
    "    repo_id=\"cyrusyc/mace-universal\",\n",
    "    subfolder=\"pretrained\",\n",
    "    filename=\"2023-12-12-mace-128-L1_epoch-199.model\",\n",
    "    revision=None,  # TODO: Add revision\n",
    ")\n",
    "\n",
    "# The .model file is presumably a pickled nn.Module rather than a bare state_dict\n",
    "# (it is wrapped directly by ModuleMLIP below) -- TODO confirm. Loading such a file\n",
    "# requires weights_only=False, which PyTorch >= 2.6 no longer uses by default.\n",
    "# NOTE(review): unpickling can execute arbitrary code; only load trusted checkpoints.\n",
    "model = torch.load(fpath, map_location=\"cpu\", weights_only=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the raw torch model so it can be saved / pushed with save_pretrained.\n",
    "module = ModuleMLIP(\n",
    "    model=model,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/atomind/mace-mp-medium/commit/eb12c5387b9e655d83a4e2e10c0f0779c3745227', commit_message='Push model using huggingface_hub.', commit_description='', oid='eb12c5387b9e655d83a4e2e10c0f0779c3745227', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Derive the Hub repo id from the class name (lower-case, hyphens): atomind/mace-mp-medium.\n",
    "# NOTE: push_to_hub=True uploads to the Hub every time this cell runs.\n",
    "hub_repo_id = \"atomind/MACE_MP_Medium\".lower().replace(\"_\", \"-\")\n",
    "\n",
    "module.save_pretrained(\n",
    "    \"mace\",\n",
    "    repo_id=hub_repo_id,\n",
    "    push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from mlip_arena.models.mace import MACE_MP_Medium\n",
    "\n",
    "# Fall back to CPU so the notebook still runs on machines without a GPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "calc = MACE_MP_Medium(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ScaleShiftMACE(\n",
       "  (node_embedding): LinearNodeEmbeddingBlock(\n",
       "    (linear): Linear(89x0e -> 128x0e | 11392 weights)\n",
       "  )\n",
       "  (radial_embedding): RadialEmbeddingBlock(\n",
       "    (bessel_fn): BesselBasis(r_max=6.0, num_basis=10, trainable=False)\n",
       "    (cutoff_fn): PolynomialCutoff(p=5.0, r_max=6.0)\n",
       "  )\n",
       "  (spherical_harmonics): SphericalHarmonics()\n",
       "  (atomic_energies_fn): AtomicEnergiesBlock(energies=[-3.6672, -1.3321, -3.4821, -4.7367, -7.7249, -8.4056, -7.3601, -7.2846, -4.8965, 0.0000, -2.7594, -2.8140, -4.8469, -7.6948, -6.9633, -4.6726, -2.8117, -0.0626, -2.6176, -5.3905, -7.8858, -10.2684, -8.6651, -9.2331, -8.3050, -7.0490, -5.5774, -5.1727, -3.2521, -1.2902, -3.5271, -4.7085, -3.9765, -3.8862, -2.5185, 6.7669, -2.5635, -4.9380, -10.1498, -11.8469, -12.1389, -8.7917, -8.7869, -7.7809, -6.8500, -4.8910, -2.0634, -0.6396, -2.7887, -3.8186, -3.5871, -2.8804, -1.6356, 9.8467, -2.7653, -4.9910, -8.9337, -8.7356, -8.0190, -8.2515, -7.5917, -8.1697, -13.5927, -18.5175, -7.6474, -8.1230, -7.6078, -6.8503, -7.8269, -3.5848, -7.4554, -12.7963, -14.1081, -9.3549, -11.3875, -9.6219, -7.3244, -5.3047, -2.3801, 0.2495, -2.3240, -3.7300, -3.4388, -5.0629, -11.0246, -12.2656, -13.8556, -14.9331, -15.2828])\n",
       "  (interactions): ModuleList(\n",
       "    (0): RealAgnosticResidualInteractionBlock(\n",
       "      (linear_up): Linear(128x0e -> 128x0e | 16384 weights)\n",
       "      (conv_tp): TensorProduct(128x0e x 1x0e+1x1o+1x2e+1x3o -> 128x0e+128x1o+128x2e+128x3o | 512 paths | 512 weights)\n",
       "      (conv_tp_weights): FullyConnectedNet[10, 64, 64, 64, 512]\n",
       "      (linear): Linear(128x0e+128x1o+128x2e+128x3o -> 128x0e+128x1o+128x2e+128x3o | 65536 weights)\n",
       "      (skip_tp): FullyConnectedTensorProduct(128x0e x 89x0e -> 128x0e+128x1o | 1458176 paths | 1458176 weights)\n",
       "      (reshape): reshape_irreps()\n",
       "    )\n",
       "    (1): RealAgnosticResidualInteractionBlock(\n",
       "      (linear_up): Linear(128x0e+128x1o -> 128x0e+128x1o | 32768 weights)\n",
       "      (conv_tp): TensorProduct(128x0e+128x1o x 1x0e+1x1o+1x2e+1x3o -> 256x0e+384x1o+384x2e+256x3o | 1280 paths | 1280 weights)\n",
       "      (conv_tp_weights): FullyConnectedNet[10, 64, 64, 64, 1280]\n",
       "      (linear): Linear(256x0e+384x1o+384x2e+256x3o -> 128x0e+128x1o+128x2e+128x3o | 163840 weights)\n",
       "      (skip_tp): FullyConnectedTensorProduct(128x0e+128x1o x 89x0e -> 128x0e | 1458176 paths | 1458176 weights)\n",
       "      (reshape): reshape_irreps()\n",
       "    )\n",
       "  )\n",
       "  (products): ModuleList(\n",
       "    (0): EquivariantProductBasisBlock(\n",
       "      (symmetric_contractions): SymmetricContraction(\n",
       "        (contractions): ModuleList(\n",
       "          (0): Contraction(\n",
       "            (contractions_weighting): ModuleList(\n",
       "              (0-1): 2 x GraphModule()\n",
       "            )\n",
       "            (contractions_features): ModuleList(\n",
       "              (0-1): 2 x GraphModule()\n",
       "            )\n",
       "            (weights): ParameterList(\n",
       "                (0): Parameter containing: [torch.float64 of size 89x4x128 (cuda:0)]\n",
       "                (1): Parameter containing: [torch.float64 of size 89x1x128 (cuda:0)]\n",
       "            )\n",
       "            (graph_opt_main): GraphModule()\n",
       "          )\n",
       "          (1): Contraction(\n",
       "            (contractions_weighting): ModuleList(\n",
       "              (0-1): 2 x GraphModule()\n",
       "            )\n",
       "            (contractions_features): ModuleList(\n",
       "              (0-1): 2 x GraphModule()\n",
       "            )\n",
       "            (weights): ParameterList(\n",
       "                (0): Parameter containing: [torch.float64 of size 89x6x128 (cuda:0)]\n",
       "                (1): Parameter containing: [torch.float64 of size 89x1x128 (cuda:0)]\n",
       "            )\n",
       "            (graph_opt_main): GraphModule()\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (linear): Linear(128x0e+128x1o -> 128x0e+128x1o | 32768 weights)\n",
       "    )\n",
       "    (1): EquivariantProductBasisBlock(\n",
       "      (symmetric_contractions): SymmetricContraction(\n",
       "        (contractions): ModuleList(\n",
       "          (0): Contraction(\n",
       "            (contractions_weighting): ModuleList(\n",
       "              (0-1): 2 x GraphModule()\n",
       "            )\n",
       "            (contractions_features): ModuleList(\n",
       "              (0-1): 2 x GraphModule()\n",
       "            )\n",
       "            (weights): ParameterList(\n",
       "                (0): Parameter containing: [torch.float64 of size 89x4x128 (cuda:0)]\n",
       "                (1): Parameter containing: [torch.float64 of size 89x1x128 (cuda:0)]\n",
       "            )\n",
       "            (graph_opt_main): GraphModule()\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (linear): Linear(128x0e -> 128x0e | 16384 weights)\n",
       "    )\n",
       "  )\n",
       "  (readouts): ModuleList(\n",
       "    (0): LinearReadoutBlock(\n",
       "      (linear): Linear(128x0e+128x1o -> 1x0e | 128 weights)\n",
       "    )\n",
       "    (1): NonLinearReadoutBlock(\n",
       "      (linear_1): Linear(128x0e -> 16x0e | 2048 weights)\n",
       "      (non_linearity): Activation [x] (16x0e -> 16x0e)\n",
       "      (linear_2): Linear(16x0e -> 1x0e | 16 weights)\n",
       "    )\n",
       "  )\n",
       "  (scale_shift): ScaleShiftBlock(scale=0.804154, shift=0.164097)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the torch module wrapped by the MACE calculator.\n",
    "calc.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from mlip_arena.models import MLIP\n",
    "\n",
    "# Reload the checkpoint pushed above, mapping it to the GPU only when one is present.\n",
    "# NOTE: this rebinds `model`, replacing the raw torch.load result from earlier.\n",
    "model = MLIP.from_pretrained(\n",
    "    \"atomind/mace-mp-medium\",\n",
    "    map_location=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
    "    revision=\"main\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `model.modules()` only returns a lazy generator (its repr shows nothing useful),\n",
    "# so materialise the top-level children to actually inspect the structure.\n",
    "[(name, type(child).__name__) for name, child in model.named_children()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MLIP has no child named `model` (valid names are those listed by\n",
    "# named_children()), so get_submodule(\"model\") raises AttributeError.\n",
    "# Catch and show it so Restart & Run All does not stop here.\n",
    "try:\n",
    "    model.get_submodule(\"model\")\n",
    "except AttributeError as err:\n",
    "    print(f\"get_submodule failed: {err}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise each parameter rather than printing its full tensor: some MACE\n",
    "# blocks (e.g. skip_tp) hold ~1.46M weights, which would flood the output.\n",
    "for name, param in model.named_parameters():\n",
    "    print(name, tuple(param.shape), param.dtype, param.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the ModuleMLIP wrapper built from the torch.load checkpoint; a bare\n",
    "# last expression uses Jupyter's rich display instead of print().\n",
    "module"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}