11"""Fold REST API interface for making HTTP calls to our fold backend."""
22
33import io
4- from typing import Literal
4+ from typing import TYPE_CHECKING , Literal
55
66import numpy as np
77from pydantic import TypeAdapter
1212
1313from .schemas import FoldJob , FoldMetadata
1414
15+ if TYPE_CHECKING :
16+ import pandas as pd
17+
1518PATH_PREFIX = "v1/fold"
1619
1720
@@ -160,8 +163,8 @@ def fold_get_complex_result(
160163def fold_get_complex_extra_result (
161164 session : APISession ,
162165 job_id : str ,
163- key : Literal ["pae" , "pde" , "plddt" , "confidence" , "affinity" ],
164- ) -> np .ndarray | list [dict ]:
166+ key : Literal ["pae" , "pde" , "plddt" , "confidence" , "affinity" , "score" , "metrics" ],
167+ ) -> " np.ndarray | list[dict] | pd.DataFrame" :
165168 """
166169 Get extra result for a complex from the request ID.
167170
@@ -183,6 +186,10 @@ def fold_get_complex_extra_result(
183186 formatter = lambda response : np .load (io .BytesIO (response .content ))
184187 elif key in {"confidence" , "affinity" }:
185188 formatter = lambda response : response .json ()
189+ elif key in {"score" , "metrics" }:
190+ import pandas as pd
191+
192+ formatter = lambda response : pd .read_csv (io .StringIO (response .content .decode ()))
186193 else :
187194 raise ValueError (f"Unexpected key: { key } " )
188195 endpoint = PATH_PREFIX + f"/{ job_id } /complex/{ key } "
@@ -194,7 +201,7 @@ def fold_get_complex_extra_result(
194201 if e .status_code == 400 and key == "affinity" :
195202 raise ValueError ("affinity not found for request" ) from None
196203 raise e
197- output : np . ndarray | list [ dict ] = formatter (response )
204+ output = formatter (response )
198205 return output
199206
200207
@@ -254,34 +261,11 @@ def fold_models_post(
254261 sequences = kwargs ["sequences" ]
255262 # NOTE we are handling the boltz form here too
256263 sequences = [s .decode () if isinstance (s , bytes ) else s for s in sequences ]
257- body ["sequences" ] = sequences
258- if kwargs .get ("msa_id" ):
259- body ["msa_id" ] = kwargs ["msa_id" ]
260- if kwargs .get ("num_recycles" ):
261- body ["num_recycles" ] = kwargs ["num_recycles" ]
262- if kwargs .get ("num_models" ):
263- body ["num_models" ] = kwargs ["num_models" ]
264- if kwargs .get ("num_relax" ):
265- body ["num_relax" ] = kwargs ["num_relax" ]
266- if kwargs .get ("use_potentials" ):
267- body ["use_potentials" ] = kwargs ["use_potentials" ]
268- # boltz
269- if kwargs .get ("diffusion_samples" ):
270- body ["diffusion_samples" ] = kwargs ["diffusion_samples" ]
271- if kwargs .get ("recycling_steps" ):
272- body ["recycling_steps" ] = kwargs ["recycling_steps" ]
273- if kwargs .get ("sampling_steps" ):
274- body ["sampling_steps" ] = kwargs ["sampling_steps" ]
275- if kwargs .get ("step_scale" ):
276- body ["step_scale" ] = kwargs ["step_scale" ]
277- if kwargs .get ("constraints" ):
278- body ["constraints" ] = kwargs ["constraints" ]
279- if kwargs .get ("templates" ):
280- body ["templates" ] = kwargs ["templates" ]
281- if kwargs .get ("properties" ):
282- body ["properties" ] = kwargs ["properties" ]
283- if kwargs .get ("method" ):
284- body ["method" ] = kwargs ["method" ]
264+ kwargs ["sequences" ] = sequences
265+ # add non-None args - note this doesnt affect msa_id which is nested
266+ for k , v in kwargs .items ():
267+ if v is not None :
268+ body [k ] = v
285269
286270 response = session .post (endpoint , json = body )
287271 return FoldJob .model_validate (response .json ())
0 commit comments