11import warnings
22from collections import defaultdict
3- from typing import List , Optional , Tuple , Union
4-
5- from emmet .core .thermo import ThermoDoc
6- from pymatgen .analysis .phase_diagram import PhaseDiagram
7-
3+ from typing import Optional , List , Tuple , Union
84from mp_api .client .core import BaseRester
95from mp_api .client .core .utils import validate_ids
6+ from emmet .core .thermo import ThermoDoc , ThermoType
7+ from pymatgen .analysis .phase_diagram import PhaseDiagram
108
119
1210class ThermoRester (BaseRester [ThermoDoc ]):
@@ -40,6 +38,8 @@ def search(
4038 is_stable : Optional [bool ] = None ,
4139 material_ids : Optional [List [str ]] = None ,
4240 num_elements : Optional [Tuple [int , int ]] = None ,
41+ thermo_ids : Optional [List [str ]] = None ,
42+ thermo_types : Optional [List [ThermoType ]] = None ,
4343 total_energy : Optional [Tuple [float , float ]] = None ,
4444 uncorrected_energy : Optional [Tuple [float , float ]] = None ,
4545 sort_fields : Optional [List [str ]] = None ,
@@ -63,6 +63,9 @@ def search(
6363 (e.g., [Fe2O3, ABO3]).
6464 is_stable (bool): Whether the material is stable.
6565 material_ids (List[str]): List of Materials Project IDs to return data for.
66+ thermo_ids (List[str]): List of thermo IDs to return data for. This is a combination of the Materials
67+ Project ID and thermo type (e.g. mp-149_GGA_GGA+U).
68+ thermo_types (List[ThermoType]): List of thermo types to return data for (e.g. ThermoType.GGA_GGA_U).
6669 num_elements (Tuple[int,int]): Minimum and maximum number of elements in the material to consider.
6770 total_energy (Tuple[float,float]): Minimum and maximum corrected total energy in eV/atom to consider.
6871 uncorrected_energy (Tuple[float,float]): Minimum and maximum uncorrected total
@@ -95,6 +98,14 @@ def search(
9598 if material_ids :
9699 query_params .update ({"material_ids" : "," .join (validate_ids (material_ids ))})
97100
101+ if thermo_ids :
102+ query_params .update ({"thermo_ids" : "," .join (validate_ids (thermo_ids ))})
103+
104+ if thermo_types :
105+ query_params .update (
106+ {"thermo_types" : "," .join ([t .value for t in thermo_types ])}
107+ )
108+
98109 if num_elements :
99110 if isinstance (num_elements , int ):
100111 num_elements = (num_elements , num_elements )
@@ -141,19 +152,23 @@ def search(
141152 ** query_params ,
142153 )
143154
144- def get_phase_diagram_from_chemsys (self , chemsys : str ) -> PhaseDiagram :
155+ def get_phase_diagram_from_chemsys (
156+ self , chemsys : str , thermo_type : ThermoType = ThermoType .GGA_GGA_U
157+ ) -> PhaseDiagram :
145158 """
146159 Get a pre-computed phase diagram for a given chemsys.
147160
148161 Arguments:
149- material_id (str): Materials project ID
162+ chemsys (str): A chemical system (e.g. Li-Fe-O)
163+ thermo_type (ThermoType): The thermo type for the phase diagram.
164+ Defaults to ThermoType.GGA_GGA_U.
150165 Returns:
151166 phase_diagram (PhaseDiagram): Pymatgen phase diagram object.
152167 """
153-
168+ phase_diagram_id = f" { chemsys } _ { thermo_type . value } "
154169 response = self ._query_resource (
155170 fields = ["phase_diagram" ],
156- suburl = f"phase_diagram/{ chemsys } " ,
171+ suburl = f"phase_diagram/{ phase_diagram_id } " ,
157172 use_document_model = False ,
158173 num_chunks = 1 ,
159174 chunk_size = 1 ,
0 commit comments