
hidet for WhisperSpeech #368

@mehdi-gital


Hi team,
I'm trying to get hidet to optimize the two models loaded on lines 15 and 16 of
https://github.com/collabora/WhisperSpeech/blob/main/whisperspeech/pipeline.py
(SADelARTransformer, on line 16, is the more important one).
I'm working with a small GPU (MX150). When I pass hidet as the backend to torch.compile, everything runs without errors, but the generation time does not improve at all. Is there anything I could do to fix this?
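Before comparing the two paths, it may help to rule out measurement effects. With `torch.compile`, the first call of a compiled function includes tracing and compilation (for hidet, typically also CUDA kernel generation), and an autoregressive loop can trigger further recompilations as input shapes change, so a single wall-clock reading can mix compile time with run time. The stdlib-only sketch below runs some untimed warm-up calls and then reports the median of several timed calls; `benchmark`, `warmup` and `repeats` are hypothetical names, not part of hidet or WhisperSpeech.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 2, repeats: int = 5) -> float:
    """Median wall-clock time of fn() in seconds, excluding warm-up calls.

    Warm-up calls absorb one-off costs such as torch.compile tracing and
    kernel compilation, so they are run but not timed.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(warmup):
        fn()  # run to trigger any one-off compilation, but do not time it
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    # The median is less sensitive than the mean to a single slow outlier run.
    return statistics.median(timings)
```

In the notebook, `fn` could be a small function that calls `generate_atoks(...)` on a `Pipeline` instance and then `torch.cuda.synchronize()`, so the clock stops only after the queued GPU kernels have finished.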

I'm testing with the notebook at https://github.com/collabora/WhisperSpeech/blob/main/Inference%20example.ipynb, using the modified Pipeline class below.

Thanks

import time

import torch
import hidet  # makes the 'hidet' backend available to torch.compile

# TSARTransformer, SADelARTransformer and Vocoder are imported as at the top of
# whisperspeech/pipeline.py.


def _now():
    # CUDA kernels run asynchronously: wait for queued GPU work before reading the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter()


class Pipeline:

    def __init__(self, use_hidet=True):
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(device)

        self.t2s = TSARTransformer.load_model().to(device)
        self.s2a = SADelARTransformer.load_model().to(device)
        if use_hidet:
            print('with hidet')
            self.t2s = torch.compile(self.t2s, backend='hidet')
            self.s2a = torch.compile(self.s2a, backend='hidet')
        else:
            print('without hidet')
        self.vocoder = Vocoder()

    def generate_atoks(self, text, speaker="8699"):
        text = text.replace("\n", " ")
        start = _now()
        stoks = self.t2s.generate(text, cps=14)
        print('t2s', _now() - start)

        start = _now()
        atoks = self.s2a.generate(stoks, [speaker])
        print('s2a', _now() - start)

        return atoks

    def generate(self, text, speaker="8699"):
        return self.vocoder.decode(self.generate_atoks(text, speaker))

    def generate_to_file(self, fname, text, speaker="8699"):
        self.vocoder.decode_to_file(fname, self.generate_atoks(text, speaker))

    def generate_to_notebook(self, text, speaker="8699"):
        start = _now()
        atoks = self.generate_atoks(text, speaker)
        print('generate_atoks(text, speaker) time', _now() - start)

        start = _now()
        self.vocoder.decode_to_notebook(atoks)
        print('vocoder.decode_to_notebook(atoks)', _now() - start)
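One possible reason for the unchanged timings, worth checking against the installed PyTorch version: `torch.compile(module)` returns a wrapper (`OptimizedModule`) whose own `__call__` goes through the compiled `forward`, while other attribute lookups, such as `generate`, are forwarded to the original module. The `generate` methods called in `generate_atoks` would then be bound to the uncompiled model, and any `self(...)` calls inside them would bypass hidet entirely. The stdlib-only sketch below mimics that forwarding with hypothetical `Model` and `Wrapper` classes; it does not use torch or hidet.

```python
class Model:
    """Hypothetical stand-in for an nn.Module with a custom generate() loop."""

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        return ("eager", x)

    def generate(self, x):
        # Calls the model through `self`, i.e. the original object,
        # not any wrapper that was built around it.
        return self(x)


class Wrapper:
    """Hypothetical analogue of torch.compile's wrapper: calling the wrapper
    takes the 'compiled' path, any other attribute is forwarded."""

    def __init__(self, orig):
        self._orig = orig

    def __call__(self, x):
        return ("compiled", x)

    def __getattr__(self, name):
        # Only reached for attributes not found on the wrapper itself.
        return getattr(self._orig, name)


wrapped = Wrapper(Model())
print(wrapped(1))           # ('compiled', 1)
print(wrapped.generate(1))  # ('eager', 1): generate is bound to the original model
```

If this is the cause, an option to try is compiling the function or submodule that `generate` invokes at every decoding step instead of the top-level module; which one that is would need checking in the WhisperSpeech source.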
