The outputs of ESMC 600M and your model (ESM++) are not fully consistent.
ESMc (logits=ForwardTrackData(sequence=tensor([[[-22.1250, -22.0000, -22.0000, 17.1250, 18.3750, 18.7500, 18.6250,
18.6250, 17.6250, 16.7500, 16.5000, 16.3750, 17.6250, 16.2500,
18.0000, 17.1250, 14.3125, 14.8125, 17.8750, 14.9375, 17.8750,
14.8750, 16.6250, 16.8750, 18.1250, 10.1250, 10.0625, 0.5273,
-8.4375, -22.0000, -22.1250, -22.0000, -22.0000, -21.8750, -22.1250,
-22.0000, -22.1250, -22.0000, -22.0000, -22.0000, -22.0000, -21.8750,
-22.0000, -22.0000, -22.1250, -21.8750, -22.1250, -22.0000, -21.8750,
-22.1250, -22.0000, -22.0000, -21.8750, -22.0000, -22.0000, -22.0000,
-22.0000, -22.0000, -22.0000, -22.0000, -22.0000, -22.0000, -22.1250,
-22.1250],
[-27.3750, -27.2500, -27.2500, 6.3438, 18.3750, 17.2500, 16.0000,
18.8750, 15.8125, 15.8750, 15.4375, 15.9375, 15.4375, 15.4375,
15.8750, 14.8750, 14.9375, 14.5625, 15.0625, 14.2500, 20.3750,
14.0625, 14.1250, 14.4375, 12.6875, 7.8125, 3.9844, 6.4062,
-15.5000, -27.3750, -27.3750, -27.2500, -27.3750, -27.2500, -27.2500,
-27.3750, -27.3750, -27.2500, -27.3750, -27.2500, -27.2500, -27.2500,
-27.3750, -27.2500, -27.2500, -27.3750, -27.2500, -27.2500, -27.2500,
-27.2500, -27.2500, -27.2500, -27.1250, -27.2500, -27.2500, -27.2500,
-27.2500, -27.2500, -27.2500, -27.2500, -27.2500, -27.2500, -27.2500,
-27.2500],
[-26.0000, -25.8750, -26.0000, 21.1250, 27.6250, 28.3750, 27.5000,
27.6250, 27.2500, 27.1250, 27.3750, 27.2500, 26.8750, 26.7500,
27.1250, 26.5000, 26.2500, 26.0000, 26.2500, 25.6250, 26.1250,
25.5000, 25.5000, 25.7500, 22.3750, 3.9062, 3.3438, 2.1094,
-10.2500, -26.0000, -26.0000, -25.8750, -26.0000, -25.8750, -25.8750,
-26.0000, -26.0000, -25.8750, -26.0000, -25.8750, -25.8750, -25.8750,
-26.0000, -25.8750, -25.8750, -25.8750, -25.8750, -25.8750, -26.0000,
-25.8750, -25.8750, -25.8750, -25.7500, -25.8750, -25.8750, -25.8750,
-25.8750, -25.7500, -26.0000, -25.8750, -25.8750, -25.8750, -25.8750,
-25.8750],
[-25.8750, -25.6250, -25.8750, 21.2500, 27.8750, 28.8750, 28.0000,
28.0000, 27.6250, 27.3750, 27.6250, 27.6250, 27.3750, 27.0000,
27.2500, 26.6250, 26.3750, 26.3750, 26.5000, 25.8750, 26.5000,
26.0000, 25.6250, 26.0000, 22.8750, 3.7344, 3.2812, 1.8047,
-10.1250, -25.8750, -25.8750, -25.7500, -25.8750, -25.7500, -25.7500,
-25.7500, -25.8750, -25.6250, -25.8750, -25.7500, -25.7500, -25.7500,
-25.7500, -25.7500, -25.7500, -25.7500, -25.7500, -25.7500, -25.7500,
-25.7500, -25.7500, -25.7500, -25.6250, -25.7500, -25.7500, -25.7500,
-25.7500, -25.6250, -25.7500, -25.7500, -25.7500, -25.7500, -25.7500,
-25.7500],
[-26.5000, -26.3750, -26.5000, 22.2500, 29.0000, 30.0000, 29.0000,
29.0000, 28.6250, 28.3750, 28.8750, 28.6250, 28.2500, 28.0000,
28.2500, 27.7500, 27.6250, 27.3750, 27.5000, 26.8750, 27.5000,
27.0000, 27.0000, 27.1250, 24.1250, 1.7031, 2.1250, 0.3672,
-10.5625, -26.5000, -26.5000, -26.3750, -26.5000, -26.3750, -26.3750,
-26.5000, -26.5000, -26.3750, -26.5000, -26.3750, -26.3750, -26.3750,
-26.5000, -26.5000, -26.5000, -26.5000, -26.3750, -26.5000, -26.5000,
-26.5000, -26.3750, -26.5000, -26.2500, -26.5000, -26.5000, -26.5000,
-26.3750, -26.3750, -26.5000, -26.3750, -26.5000, -26.3750, -26.5000,
-26.5000],
[-22.7500, -22.6250, -22.7500, 23.2500, 22.1250, 23.6250, 22.3750,
22.2500, 21.7500, 21.8750, 22.2500, 21.5000, 21.1250, 21.0000,
21.7500, 20.8750, 21.0000, 20.2500, 20.5000, 19.8750, 20.6250,
19.8750, 20.6250, 20.1250, 18.1250, 7.1875, 7.7188, 5.3125,
-9.8750, -22.7500, -22.7500, -22.7500, -22.8750, -22.7500, -22.7500,
-22.7500, -22.7500, -22.6250, -22.7500, -22.7500, -22.7500, -22.7500,
-22.7500, -22.7500, -22.7500, -22.7500, -22.7500, -22.7500, -22.7500,
-22.7500, -22.7500, -22.7500, -22.6250, -22.7500, -22.7500, -22.7500,
-22.6250, -22.6250, -22.7500, -22.7500, -22.7500, -22.6250, -22.7500,
-22.7500],
[-22.2500, -22.1250, -22.2500, 24.6250, 22.8750, 25.3750, 23.7500,
23.2500, 23.6250, 23.0000, 23.5000, 23.2500, 22.2500, 22.5000,
23.3750, 23.1250, 22.1250, 22.0000, 21.7500, 21.1250, 21.7500,
21.8750, 21.2500, 21.6250, 22.8750, 7.6875, 11.5000, 7.6562,
-9.0625, -22.3750, -22.2500, -22.2500, -22.3750, -22.2500, -22.2500,
-22.2500, -22.2500, -22.1250, -22.3750, -22.2500, -22.2500, -22.2500,
-22.2500, -22.2500, -22.2500, -22.2500, -22.2500, -22.2500, -22.2500,
-22.2500, -22.2500, -22.2500, -22.1250, -22.2500, -22.2500, -22.2500,
-22.1250, -22.1250, -22.2500, -22.1250, -22.2500, -22.2500, -22.2500,
-22.2500]]], device='cuda:0', dtype=torch.bfloat16), structure=None, secondary_structure=None, sasa=None, function=None), embeddings=tensor([[[-0.0085, 0.0046, -0.0002, ..., -0.0068, 0.0031, 0.0056],
[-0.0178, 0.0429, 0.0241, ..., 0.0105, 0.0029, 0.0272],
[-0.0167, 0.0091, 0.0185, ..., 0.0260, -0.0064, 0.0218],
...,
[-0.0226, 0.0120, 0.0300, ..., 0.0076, 0.0069, 0.0082],
[-0.0004, 0.0092, 0.0228, ..., 0.0314, -0.0065, 0.0313],
[-0.0091, 0.0163, -0.0316, ..., -0.0014, -0.0222, 0.0074]]],
device='cuda:0'), residue_annotation_logits=None, hidden_states=None)
ESMplusplusOutput (loss=None, logits=tensor([[[-22.6131, -22.4821, -22.5597, 17.2095, 18.8339, 19.0829, 18.8579,
19.0113, 17.8731, 17.0502, 16.9101, 16.8187, 17.9649, 16.5600,
18.2422, 17.4132, 14.7632, 15.2010, 18.1620, 15.2264, 18.6486,
15.2754, 16.8730, 17.1691, 18.3985, 10.2630, 10.0079, 0.2372,
-9.1027, -22.5591, -22.6386, -22.5791, -22.5049, -22.4569, -22.6224,
-22.5131, -22.5976, -22.5826, -22.5811, -22.5380, -22.5192, -22.4622,
-22.4984, -22.5364, -22.6507, -22.4547, -22.6168, -22.5447, -22.4690,
-22.5982, -22.5173, -22.5599, -22.3434, -22.5328, -22.4980, -22.5756,
-22.5273, -22.4656, -22.5479, -22.4857, -22.5090, -22.5393, -22.6139,
-22.6183],
[-28.4831, -28.3567, -28.4582, 7.2065, 19.5353, 17.8348, 17.1140,
19.8591, 17.0163, 17.0513, 16.8235, 16.9516, 16.6259, 16.8195,
17.0532, 16.2335, 16.3059, 15.9805, 16.2464, 15.7089, 21.3718,
15.5317, 15.4345, 15.7049, 13.4264, 5.6808, 2.5116, 3.4380,
-16.2709, -28.4989, -28.5415, -28.4569, -28.5197, -28.3934, -28.4740,
-28.5036, -28.4780, -28.3974, -28.5291, -28.4606, -28.4478, -28.4250,
-28.4879, -28.4249, -28.4668, -28.4702, -28.4221, -28.4475, -28.3922,
-28.4335, -28.4424, -28.4270, -28.3188, -28.4597, -28.4645, -28.4211,
-28.3837, -28.3785, -28.4568, -28.4035, -28.4443, -28.4279, -28.4683,
-28.4354],
[-26.7976, -26.6178, -26.7191, 22.1797, 30.0713, 30.4251, 29.8870,
29.9532, 29.7922, 29.5880, 29.8906, 29.7987, 29.5994, 29.2841,
29.6473, 29.2176, 28.8841, 28.8778, 28.8368, 28.3084, 28.7634,
28.3530, 27.9666, 28.2659, 24.0114, -0.8772, -1.0498, -2.3372,
-10.3790, -26.7547, -26.7898, -26.6962, -26.7614, -26.6307, -26.6968,
-26.7301, -26.7370, -26.6102, -26.7523, -26.6877, -26.6964, -26.6942,
-26.7515, -26.7087, -26.7297, -26.7117, -26.6716, -26.7051, -26.7305,
-26.7148, -26.6948, -26.7027, -26.5538, -26.7359, -26.7288, -26.7156,
-26.6430, -26.5687, -26.7441, -26.5873, -26.7225, -26.6490, -26.7439,
-26.7104],
[-27.1739, -26.9939, -27.1021, 23.0664, 31.1326, 31.7154, 31.0919,
30.9792, 30.9618, 30.6432, 31.0817, 30.8716, 30.5479, 30.3123,
30.6569, 30.1458, 30.0037, 29.7907, 29.9133, 29.3552, 29.7271,
29.4474, 29.2505, 29.3838, 25.2748, -3.1373, -2.1763, -4.4798,
-10.7190, -27.1392, -27.1712, -27.0749, -27.1418, -27.0056, -27.0678,
-27.1048, -27.1219, -26.9872, -27.1350, -27.0662, -27.0641, -27.0702,
-27.1324, -27.0911, -27.1012, -27.0834, -27.0374, -27.0862, -27.1111,
-27.0987, -27.0726, -27.0940, -26.9264, -27.1203, -27.1108, -27.1094,
-27.0210, -26.9409, -27.1313, -26.9557, -27.0992, -27.0298, -27.1288,
-27.0911],
[-23.8997, -23.7603, -23.8484, 26.2622, 24.8136, 25.5235, 24.8515,
24.5291, 24.4971, 24.3121, 24.9175, 24.1075, 23.9324, 23.7328,
24.3812, 23.6493, 23.7462, 23.1297, 23.3478, 22.8957, 23.1977,
22.8750, 23.1623, 23.0427, 19.8427, 3.2716, 4.6801, 1.2013,
-10.5358, -23.9165, -23.9042, -23.8356, -23.9010, -23.8193, -23.8113,
-23.8627, -23.8919, -23.7594, -23.8876, -23.8306, -23.8309, -23.8218,
-23.8962, -23.8502, -23.8697, -23.8455, -23.8186, -23.8383, -23.8664,
-23.8439, -23.8413, -23.8480, -23.7054, -23.8808, -23.8536, -23.8435,
-23.7684, -23.7370, -23.8890, -23.7747, -23.8712, -23.7957, -23.9020,
-23.8487],
[-23.1306, -22.9653, -23.0733, 26.3065, 24.2521, 25.9444, 24.8271,
24.3144, 24.7693, 24.1082, 24.7063, 24.3209, 23.6993, 23.7231,
24.5121, 24.3988, 23.3890, 23.3675, 23.2677, 22.5603, 23.1209,
23.1623, 22.5943, 23.0523, 22.7931, 5.9679, 9.2470, 4.5348,
-9.8923, -23.1562, -23.1167, -23.0736, -23.1464, -23.0331, -23.0582,
-23.0821, -23.1101, -23.0115, -23.1267, -23.0533, -23.0692, -23.0486,
-23.1065, -23.0975, -23.1120, -23.0666, -23.0482, -23.0980, -23.0964,
-23.0799, -23.0405, -23.0312, -22.9158, -23.0886, -23.0651, -23.0713,
-22.9998, -22.9620, -23.1208, -22.9704, -23.0561, -23.0216, -23.1095,
-23.0899]]], grad_fn=), last_hidden_state=tensor([[[-0.0059, 0.0067, -0.0009, ..., -0.0074, 0.0060, 0.0054],
[-0.0002, 0.0599, 0.0088, ..., 0.0134, -0.0024, 0.0160],
[-0.0045, 0.0138, 0.0092, ..., 0.0245, -0.0009, 0.0086],
[-0.0108, 0.0288, 0.0163, ..., 0.0118, 0.0085, -0.0099],
[ 0.0129, 0.0155, 0.0128, ..., 0.0285, 0.0105, 0.0154],
[ 0.0023, 0.0294, -0.0348, ..., -0.0042, -0.0060, -0.0029]]],
grad_fn=), hidden_states=None, attentions=None)
environment setup:
packages in environment at /home/xxx/anaconda3/envs/esmc:
Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
anyio 4.10.0 pypi_0 pypi
asttokens 3.0.0 pypi_0 pypi
attrs 25.3.0 pypi_0 pypi
biopython 1.85 pypi_0 pypi
biotite 1.2.0 pypi_0 pypi
biotraj 1.2.2 pypi_0 pypi
brotli 1.1.0 pypi_0 pypi
bzip2 1.0.8 h5eee18b_6
ca-certificates 2025.7.15 h06a4308_0
certifi 2025.8.3 pypi_0 pypi
charset-normalizer 3.4.3 pypi_0 pypi
cloudpathlib 0.21.1 pypi_0 pypi
decorator 5.2.1 pypi_0 pypi
einops 0.8.1 pypi_0 pypi
esm 3.2.1 pypi_0 pypi
exceptiongroup 1.3.0 pypi_0 pypi
executing 2.2.0 pypi_0 pypi
expat 2.7.1 h6a678d5_0
filelock 3.18.0 pypi_0 pypi
fsspec 2025.7.0 pypi_0 pypi
h11 0.16.0 pypi_0 pypi
hf-xet 1.1.7 pypi_0 pypi
httpcore 1.0.9 pypi_0 pypi
httpx 0.28.1 pypi_0 pypi
huggingface-hub 0.34.4 pypi_0 pypi
idna 3.10 pypi_0 pypi
ipython 8.37.0 pypi_0 pypi
jedi 0.19.2 pypi_0 pypi
jinja2 3.1.6 pypi_0 pypi
joblib 1.5.1 pypi_0 pypi
ld_impl_linux-64 2.40 h12ee557_0
libffi 3.4.4 h6a678d5_1
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
libxcb 1.17.0 h9b100fa_0
markupsafe 3.0.2 pypi_0 pypi
matplotlib-inline 0.1.7 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
msgpack 1.1.1 pypi_0 pypi
msgpack-numpy 0.4.8 pypi_0 pypi
ncurses 6.5 h7934f7d_0
networkx 3.4.2 pypi_0 pypi
numpy 2.2.6 pypi_0 pypi
nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
nvidia-cufile-cu12 1.13.1.3 pypi_0 pypi
nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
nvidia-nccl-cu12 2.27.3 pypi_0 pypi
nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
openssl 3.0.17 h5eee18b_0
packaging 25.0 pypi_0 pypi
pandas 2.3.1 pypi_0 pypi
parso 0.8.4 pypi_0 pypi
pexpect 4.9.0 pypi_0 pypi
pillow 11.3.0 pypi_0 pypi
pip 25.1 pyhc872135_2
prompt-toolkit 3.0.51 pypi_0 pypi
pthread-stubs 0.3 h0ce48e5_1
ptyprocess 0.7.0 pypi_0 pypi
pure-eval 0.2.3 pypi_0 pypi
pygments 2.19.2 pypi_0 pypi
python 3.10.18 h1a3bd86_0
python-dateutil 2.9.0.post0 pypi_0 pypi
pytz 2025.2 pypi_0 pypi
pyyaml 6.0.2 pypi_0 pypi
readline 8.3 hc2a1206_0
regex 2025.7.34 pypi_0 pypi
requests 2.32.4 pypi_0 pypi
safetensors 0.6.2 pypi_0 pypi
scikit-learn 1.7.1 pypi_0 pypi
scipy 1.15.3 pypi_0 pypi
setuptools 78.1.1 py310h06a4308_0
six 1.17.0 pypi_0 pypi
sniffio 1.3.1 pypi_0 pypi
sqlite 3.50.2 hb25bd0a_1
stack-data 0.6.3 pypi_0 pypi
sympy 1.14.0 pypi_0 pypi
tenacity 9.1.2 pypi_0 pypi
threadpoolctl 3.6.0 pypi_0 pypi
tk 8.6.14 h993c535_1
tokenizers 0.21.4 pypi_0 pypi
torch 2.8.0 pypi_0 pypi
torchtext 0.18.0 pypi_0 pypi
torchvision 0.23.0 pypi_0 pypi
tqdm 4.67.1 pypi_0 pypi
traitlets 5.14.3 pypi_0 pypi
transformers 4.48.1 pypi_0 pypi
triton 3.4.0 pypi_0 pypi
typing-extensions 4.14.1 pypi_0 pypi
tzdata 2025.2 pypi_0 pypi
urllib3 2.5.0 pypi_0 pypi
wcwidth 0.2.13 pypi_0 pypi
wheel 0.45.1 py310h06a4308_0
xorg-libx11 1.8.12 h9b100fa_1
xorg-libxau 1.0.12 h9b100fa_0
xorg-libxdmcp 1.1.5 h9b100fa_0
xorg-xorgproto 2024.1 h5eee18b_1
xz 5.6.4 h5eee18b_1
zlib 1.2.13 h5eee18b_1
zstd 1.5.7.2 pypi_0 pypi
Why ESMC 600M and ESM++ outputs may differ
Hi @yfxf9868,
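I apologize for not responding to your issue sooner. Did you resolve this issue? My guess is that the difference comes from numerical precision: the native precision of ESM++ is bfloat16, and for ESMC it is float32. Note also that the two runs you posted do not use the same precision or device. The ESMC logits are bfloat16 on cuda:0, while the ESM++ logits are float32 and appear to have been computed on the CPU. To compare two very large tensors, torch.allclose() (with a chosen atol/rtol) tells you whether they agree within a tolerance, and (a - b).abs().max() gives the maximal difference. torch.testing.assert_close() is also helpful: when the tensors do not match, its error message reports the greatest absolute and relative differences.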
If you have further issues, please let me know.
Best,
Logan