
T5 DeepSpeed Interface is actually slower than PyTorch #2242

@Oxi84

Description


I expected the DeepSpeed interface to be faster than PyTorch, but it is actually around 10 percent slower: PyTorch does it in 0.28 s and DeepSpeed in around 0.30 s.

Is there anything I am doing wrong? Maybe I should use something other than replace_method='auto'?

import time

import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast

print("begin loading")

model1a = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer1 = T5TokenizerFast.from_pretrained("t5-base", cache_dir="/root/Desktop/model_cache/")


# Build a batch of 30 identical prompts. The explicit " </s>" is redundant
# (the tokenizer appends the EOS token itself) but is kept from the original run.
sentence_list = ["I like turtles because they are slow."] * 30
text = ["slow: " + sentence + " </s>" for sentence in sentence_list]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoding = tokenizer1.batch_encode_plus(text, padding="longest", return_tensors="pt")
input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)
    

import deepspeed

# Wrap the model in DeepSpeed's inference engine; with replace_with_kernel_inject=True
# the matching Hugging Face modules are swapped for DeepSpeed's fused kernels.
ds_engine = deepspeed.init_inference(model1a,
                                     mp_size=1,
                                     dtype=torch.half,
                                     replace_method='auto',
                                     replace_with_kernel_inject=True)
model1ab = ds_engine.module

# PyTorch-only fp16 baseline (the 0.28 s path), run instead of the engine above:
# model1a.half()
# model1a.to("cuda")
# model1a.eval()

begin = time.time()
with torch.no_grad():
    beam_outputs = model1ab.generate(
        input_ids=input_ids, attention_mask=attention_masks,
        do_sample=False,
        num_beams=4,
        max_length=128,
        num_return_sequences=4,
        # output_scores=True
    )

print("it took model", time.time() - begin)
final_outputs = tokenizer1.batch_decode(beam_outputs, skip_special_tokens=True)

print("final_outputs", final_outputs)
print("it took total", time.time() - begin)
print("###")
