11#!/usr/bin/env python3
2- from subprocess import check_call
2+ import shutil
3+ import sys
34from pathlib import Path
5+ from subprocess import check_call
46from tempfile import TemporaryDirectory
57from typing import Optional
6- import sys
7- import shutil
8+
89SCRIPT_DIR = Path (__file__ ).parent
910REPO_DIR = SCRIPT_DIR .parent .parent
1011
12+
1113def read_triton_pin () -> str :
1214 with open (REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / "triton.txt" ) as f :
1315 return f .read ().strip ()
@@ -19,7 +21,7 @@ def read_triton_version() -> str:
1921
2022
2123def check_and_replace (inp : str , src : str , dst : str ) -> str :
22- """ Checks that `src` can be found in `input` and replaces it with `dst` """
24+ """Checks that `src` can be found in `input` and replaces it with `dst`"""
2325 if src not in inp :
2426 raise RuntimeError (f"Can't find ${ src } in the input" )
2527 return inp .replace (src , dst )
@@ -29,9 +31,11 @@ def patch_setup_py(path: Path, *, version: str, name: str = "triton") -> None:
2931 with open (path ) as f :
3032 orig = f .read ()
3133 # Replace name
32- orig = check_and_replace (orig , " name=\ " triton\" ," , f" name=\ "{ name } \" ," )
34+ orig = check_and_replace (orig , ' name="triton",' , f' name="{ name } ",' )
3335 # Replace version
34- orig = check_and_replace (orig , f"version=\" { read_triton_version ()} \" ," , f"version=\" { version } \" ," )
36+ orig = check_and_replace (
37+ orig , f'version="{ read_triton_version ()} ",' , f'version="{ version } ",'
38+ )
3539 with open (path , "w" ) as f :
3640 f .write (orig )
3741
@@ -40,39 +44,81 @@ def patch_init_py(path: Path, *, version: str) -> None:
4044 with open (path ) as f :
4145 orig = f .read ()
4246 # Replace version
43- orig = check_and_replace (orig , f"__version__ = '{ read_triton_version ()} '" , f"__version__ = \" { version } \" " )
47+ orig = check_and_replace (
48+ orig , f"__version__ = '{ read_triton_version ()} '" , f'__version__ = "{ version } "'
49+ )
4450 with open (path , "w" ) as f :
4551 f .write (orig )
4652
4753
48- def build_triton (* , version : str , commit_hash : str , build_conda : bool = False , py_version : Optional [str ] = None ) -> Path :
54+ def build_triton (
55+ * ,
56+ version : str ,
57+ commit_hash : str ,
58+ build_conda : bool = False ,
59+ py_version : Optional [str ] = None ,
60+ ) -> Path :
4961 with TemporaryDirectory () as tmpdir :
5062 triton_basedir = Path (tmpdir ) / "triton"
5163 triton_pythondir = triton_basedir / "python"
5264 check_call (["git" , "clone" , "https://github.com/openai/triton" ], cwd = tmpdir )
5365 check_call (["git" , "checkout" , commit_hash ], cwd = triton_basedir )
5466 if build_conda :
5567 with open (triton_basedir / "meta.yaml" , "w" ) as meta :
56- print (f"package:\n name: torchtriton\n version: { version } +{ commit_hash [:10 ]} \n " , file = meta )
68+ print (
69+ f"package:\n name: torchtriton\n version: { version } +{ commit_hash [:10 ]} \n " ,
70+ file = meta ,
71+ )
5772 print ("source:\n path: .\n " , file = meta )
58- print ("build:\n string: py{{py}}\n number: 1\n script: cd python; "
59- "python setup.py install --single-version-externally-managed --record=record.txt\n " , file = meta )
60- print ("requirements:\n host:\n - python\n - setuptools\n run:\n - python\n "
61- " - filelock\n - pytorch\n " , file = meta )
62- print ("about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
63- " 'A language and compiler for custom Deep Learning operation'" , file = meta )
64-
65- patch_init_py (triton_pythondir / "triton" / "__init__.py" , version = f"{ version } +{ commit_hash [:10 ]} " )
73+ print (
74+ "build:\n string: py{{py}}\n number: 1\n script: cd python; "
75+ "python setup.py install --single-version-externally-managed --record=record.txt\n " ,
76+ file = meta ,
77+ )
78+ print (
79+ "requirements:\n host:\n - python\n - setuptools\n run:\n - python\n "
80+ " - filelock\n - pytorch\n " ,
81+ file = meta ,
82+ )
83+ print (
84+ "about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
85+ " 'A language and compiler for custom Deep Learning operation'" ,
86+ file = meta ,
87+ )
88+
89+ patch_init_py (
90+ triton_pythondir / "triton" / "__init__.py" ,
91+ version = f"{ version } +{ commit_hash [:10 ]} " ,
92+ )
6693 if py_version is None :
6794 py_version = f"{ sys .version_info .major } .{ sys .version_info .minor } "
68- check_call (["conda" , "build" , "--python" , py_version ,
69- "-c" , "pytorch-nightly" , "--output-folder" , tmpdir , "." ], cwd = triton_basedir )
95+ check_call (
96+ [
97+ "conda" ,
98+ "build" ,
99+ "--python" ,
100+ py_version ,
101+ "-c" ,
102+ "pytorch-nightly" ,
103+ "--output-folder" ,
104+ tmpdir ,
105+ "." ,
106+ ],
107+ cwd = triton_basedir ,
108+ )
70109 conda_path = list (Path (tmpdir ).glob ("linux-64/torchtriton*.bz2" ))[0 ]
71110 shutil .copy (conda_path , Path .cwd ())
72111 return Path .cwd () / conda_path .name
73112
74- patch_setup_py (triton_pythondir / "setup.py" , name = "pytorch-triton" , version = f"{ version } +{ commit_hash [:10 ]} " )
75- patch_init_py (triton_pythondir / "triton" / "__init__.py" , version = f"{ version } +{ commit_hash [:10 ]} " )
113+ patch_setup_py (
114+ triton_pythondir / "setup.py" ,
115+ name = "pytorch-triton" ,
116+ version = f"{ version } +{ commit_hash [:10 ]} " ,
117+ )
118+ patch_init_py (
119+ triton_pythondir / "triton" / "__init__.py" ,
120+ version = f"{ version } +{ commit_hash [:10 ]} " ,
121+ )
76122 check_call ([sys .executable , "setup.py" , "bdist_wheel" ], cwd = triton_pythondir )
77123 whl_path = list ((triton_pythondir / "dist" ).glob ("*.whl" ))[0 ]
78124 shutil .copy (whl_path , Path .cwd ())
@@ -81,16 +127,19 @@ def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, p
81127
82128def main () -> None :
83129 from argparse import ArgumentParser
130+
84131 parser = ArgumentParser ("Build Triton binaries" )
85132 parser .add_argument ("--build-conda" , action = "store_true" )
86133 parser .add_argument ("--py-version" , type = str )
87134 parser .add_argument ("--commit-hash" , type = str , default = read_triton_pin ())
88135 parser .add_argument ("--triton-version" , type = str , default = read_triton_version ())
89136 args = parser .parse_args ()
90- build_triton (commit_hash = args .commit_hash ,
91- version = args .triton_version ,
92- build_conda = args .build_conda ,
93- py_version = args .py_version )
137+ build_triton (
138+ commit_hash = args .commit_hash ,
139+ version = args .triton_version ,
140+ build_conda = args .build_conda ,
141+ py_version = args .py_version ,
142+ )
94143
95144
96145if __name__ == "__main__" :
0 commit comments