Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cookbook/local/open_generate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.models.esm3 import ESM3\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
Expand Down
5 changes: 3 additions & 2 deletions cookbook/local/raw_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import torch
import torch.nn.functional as F

from esm.pretrained import (
ESM3_function_decoder_v0,
ESM3_sm_open_v0,
Expand All @@ -13,7 +12,9 @@
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation

Expand Down
1 change: 0 additions & 1 deletion cookbook/snippets/fold_invfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import cast

import numpy as np

from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
Expand Down
4 changes: 1 addition & 3 deletions cookbook/tutorials/1_esmprotein.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
"outputs": [],
"source": [
"from biotite.database import rcsb\n",
"\n",
"from esm.sdk.api import ESMProtein\n",
"from esm.utils.structure.protein_chain import ProteinChain\n",
"from esm.utils.types import FunctionAnnotation\n",
Expand Down Expand Up @@ -497,9 +496,8 @@
"# Functions for visualizing InterPro function annotations\n",
"\n",
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
"from matplotlib import colormaps\n",
"\n",
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
"from matplotlib import colormaps\n",
"\n",
"\n",
"def visualize_function_annotations(\n",
Expand Down
6 changes: 3 additions & 3 deletions cookbook/tutorials/2_embed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge console: \")"
"token = getpass(\"Token from Forge: \")"
]
},
{
Expand Down
7 changes: 3 additions & 4 deletions cookbook/tutorials/3_gfp_design.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
"import matplotlib.pyplot as pl\n",
"import py3Dmol\n",
"import torch\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
Expand All @@ -80,18 +79,18 @@
"\n",
"The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n",
"\n",
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"id": "zNrU9Q2SYonX"
},
"outputs": [],
"source": [
"token = getpass(\"Token from Forge console: \")"
"token = getpass(\"Token from Forge: \")"
]
},
{
Expand Down
5 changes: 2 additions & 3 deletions cookbook/tutorials/4_forge_generate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
Expand All @@ -53,7 +52,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
Expand All @@ -64,7 +63,7 @@
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge console: \")\n",
"token = getpass(\"Token from Forge: \")\n",
"model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
]
},
Expand Down
3 changes: 1 addition & 2 deletions cookbook/tutorials/5_guided_generation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
"source": [
"import biotite.structure as bs\n",
"import py3Dmol\n",
"\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
]
Expand Down Expand Up @@ -120,7 +119,7 @@
"\n",
"from esm.sdk import client\n",
"\n",
"token = getpass(\"Token from Forge console: \")\n",
"token = getpass(\"Token from Forge: \")\n",
"model = client(\n",
" model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n",
")"
Expand Down
3 changes: 2 additions & 1 deletion esm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
__version__ = "3.2.2"
__version__ = "3.2.2.post2"

5 changes: 4 additions & 1 deletion esm/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import torch.nn.functional as F
from torch import nn

from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding
from esm.layers.rotary import (
RotaryEmbedding,
TritonRotaryEmbedding,
)

try:
from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore
Expand Down
9 changes: 7 additions & 2 deletions esm/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
import torch.nn as nn
import torch.nn.functional as F

from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention
from esm.layers.geom_attention import GeometricReasoningOriginalImpl
from esm.layers.attention import (
FlashMultiHeadAttention,
MultiHeadAttention,
)
from esm.layers.geom_attention import (
GeometricReasoningOriginalImpl,
)
from esm.utils.structure.affine3d import Affine3D


Expand Down
5 changes: 4 additions & 1 deletion esm/layers/structure_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import torch.nn as nn

from esm.utils.constants.physics import BB_COORDINATES
from esm.utils.structure.affine3d import Affine3D, RotationMatrix
from esm.utils.structure.affine3d import (
Affine3D,
RotationMatrix,
)


class Dim6RotStructureHead(nn.Module):
Expand Down
14 changes: 11 additions & 3 deletions esm/models/esm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
Expand All @@ -29,7 +32,10 @@
from esm.tokenization import TokenizerCollectionProtocol
from esm.utils import encoding
from esm.utils.constants import esm3 as C
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
from esm.utils.constants.models import (
ESM3_OPEN_SMALL,
normalize_model_name,
)
from esm.utils.decoding import decode_protein_tensor
from esm.utils.generation import (
_batch_forward,
Expand All @@ -44,7 +50,9 @@
get_default_sampling_config,
validate_sampling_config,
)
from esm.utils.structure.affine3d import build_affine3d_from_coordinates
from esm.utils.structure.affine3d import (
build_affine3d_from_coordinates,
)


@dataclass
Expand Down
4 changes: 3 additions & 1 deletion esm/models/function_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.misc import merge_annotations, merge_ranges
from esm.utils.types import FunctionAnnotation
Expand Down
5 changes: 4 additions & 1 deletion esm/models/vqvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from esm.layers.transformer_stack import TransformerStack
from esm.utils.constants import esm3 as C
from esm.utils.misc import knn_graph
from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates
from esm.utils.structure.affine3d import (
Affine3D,
build_affine3d_from_coordinates,
)
from esm.utils.structure.predicted_aligned_error import (
compute_predicted_aligned_error,
compute_tm,
Expand Down
10 changes: 8 additions & 2 deletions esm/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.tokenization import (
get_esm3_model_tokenizers,
get_esmc_model_tokenizers,
)
from esm.utils.constants.esm3 import data_root
from esm.utils.constants.models import (
ESM3_FUNCTION_DECODER_V0,
Expand Down
54 changes: 42 additions & 12 deletions esm/sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,27 @@

from abc import ABC
from copy import deepcopy
from typing import Sequence
from typing import List, Sequence

import attr
import torch
from attr import asdict, define

import esm.utils.constants.api as C
from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers
from esm.tokenization import (
TokenizerCollectionProtocol,
get_esm3_model_tokenizers,
)
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.misc import get_chainbreak_boundaries_from_sequence
from esm.utils.misc import (
get_chainbreak_boundaries_from_sequence,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex
from esm.utils.structure.protein_complex import (
SINGLE_LETTER_CHAIN_IDS,
ProteinComplex,
)
from esm.utils.types import FunctionAnnotation, PathOrBuffer


Expand All @@ -35,6 +43,7 @@ class ESMProtein(ProteinType):
plddt: torch.Tensor | None = None
ptm: torch.Tensor | None = None


# When calling EvolutionaryScale API, use this flag to disclose any
# sequences that may potentially have concerns.
# Such sequences may not go through standard safety filter for approved users.
Expand Down Expand Up @@ -148,20 +157,43 @@ def to_protein_complex(
gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
else:
gt_chains = None

# Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks
# This handles the case where the server doesn't include chain breaks in pLDDT
# We should fix this in the server side.
if self.plddt is not None and len(self.plddt) != len(self.sequence):
# Only expand if there's a mismatch (likely due to chain breaks)
if "|" in self.sequence:
# Create expanded pLDDT with NaN at chain break positions
expanded_plddt = torch.full((len(self.sequence),), float("nan"))
plddt_idx = 0
for i, aa in enumerate(self.sequence):
if aa != "|":
if plddt_idx < len(self.plddt):
expanded_plddt[i] = self.plddt[plddt_idx]
plddt_idx += 1
plddt = expanded_plddt
else:
# Mismatch but no chain breaks - shouldn't happen but preserve original
plddt = self.plddt
else:
plddt = self.plddt

pred_chains = []
for i, (start, end) in enumerate(chain_boundaries):
if i >= len(SINGLE_LETTER_CHAIN_IDS):
raise ValueError(
f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}"
)

pred_chain = ProteinChain.from_atom37(
atom37_positions=coords[start:end],
sequence=self.sequence[start:end],
chain_id=gt_chains[i].chain_id
if gt_chains is not None
else SINGLE_LETTER_CHAIN_IDS[i],
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
confidence=self.plddt[start:end] if self.plddt is not None else None,
confidence=plddt[start:end] if plddt is not None else None,
)
pred_chains.append(pred_chain)
return ProteinComplex.from_chains(pred_chains)
Expand Down Expand Up @@ -298,19 +330,14 @@ def use_generative_unmasking_strategy(self):
self.temperature_annealing = True


@define
class MSA:
# Paired MSA sequences.
# One would typically compute these using, for example, ColabFold.
sequences: list[str]


@define
class InverseFoldingConfig:
invalid_ids: Sequence[int] = []
temperature: float = 1.0




## Low Level Endpoint Types
@define
class SamplingTrackConfig:
Expand Down Expand Up @@ -375,6 +402,9 @@ class LogitsConfig:
ith_hidden_layer: int = -1





@define
class LogitsOutput:
logits: ForwardTrackData | None = None
Expand Down
Loading
Loading