Deadlock/Hanging/Stagnation/Stall when using BERT model in parallel with Torch/PyTorch
Hey guys,
I'm trying to parallelize and queue machine learning jobs and have run into a problem. Occasionally, the BERT question-answering model permanently hangs when run in multiple workers at the same time, and the job is eventually force-stopped by the timeout. I can't figure out why this is happening.
My guess is that the BERT model instances, despite being in separate workers, are somehow communicating with each other or relying on the same memory, causing a deadlock.
The code
run_worker.py
```python
# -*- coding: utf-8 -*-
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

from rq import Connection, Queue, SimpleWorker

from worker_test import *

if __name__ == '__main__':
    # Tell rq what Redis connection to use
    with Connection():
        job_init()
        q = Queue()
        SimpleWorker(q).work()
```
run_test_job.py
```python
from redis import Redis
from rq import Queue
import time

from worker_test import run_bert

redis_conn = Redis()
async_results = {}
q = Queue(connection=redis_conn)

start_time = time.time()
for i in range(30):
    async_results[i] = q.enqueue(run_bert, "Who?", "John eats a lot of pizza")

done = False
complete = 0
while not done:
    done = True
    count = 0
    for i in range(len(async_results)):
        result = async_results[i].return_value
        if result is None:
            done = False
        else:
            count += 1
    if count > complete:
        complete = count
        print(complete)
    time.sleep(1)
end_time = time.time()

print(end_time - start_time)
```
worker_test.py
```python
import torch
from transformers import BertForQuestionAnswering, BertTokenizer, pipeline

def job_init():
    global nlp_pipe
    global model
    global tokenizer
    model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
    tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
    nlp_pipe = pipeline('question-answering', model=model, tokenizer=tokenizer)

def run_bert(question, text):
    output = nlp_pipe({
        'question': question,
        'context': text
    })
    output = nlp_pipe({
        'question': question,
        'context': text
    })
    output = nlp_pipe({
        'question': question,
        'context': text
    })
    output = nlp_pipe({
        'question': question,
        'context': text
    })
    print(output)
    return output
```
The error
```
13:04:14 default: worker_test.run_bert('Who?', 'John eats a lot of pizza') (f4eab269-6d4c-4273-9084-1ab031b6d01c)
13:04:19 Warm shut down requested
13:04:23 Warm shut down requested
13:04:32 Warm shut down requested
13:07:14 Traceback (most recent call last):
  File "/home/you/.local/lib/python3.9/site-packages/rq/worker.py", line 1061, in perform_job
    rv = job.perform()
  File "/home/you/.local/lib/python3.9/site-packages/rq/job.py", line 821, in perform
    self._result = self._execute()
  File "/home/you/.local/lib/python3.9/site-packages/rq/job.py", line 844, in _execute
    result = self.func(*self.args, **self.kwargs)
  File "/home/you/Downloads/rq-master/examples/worker_test.py", line 31, in run_bert
    output = nlp_pipe({
  File "/home/you/.local/lib/python3.9/site-packages/transformers/pipelines/question_answering.py", line 248, in __call__
    return super().__call__(examples[0], **kwargs)
  File "/home/you/.local/lib/python3.9/site-packages/transformers/pipelines/base.py", line 915, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/you/.local/lib/python3.9/site-packages/transformers/pipelines/base.py", line 921, in run_single
    model_inputs = self.preprocess(inputs, **preprocess_params)
  File "/home/you/.local/lib/python3.9/site-packages/transformers/pipelines/question_answering.py", line 259, in preprocess
    features = squad_convert_examples_to_features(
  File "/home/you/.local/lib/python3.9/site-packages/transformers/data/processors/squad.py", line 377, in squad_convert_examples_to_features
    features = list(
  File "/usr/lib/python3.9/multiprocessing/pool.py", line 736, in __exit__
    self.terminate()
  File "/usr/lib/python3.9/multiprocessing/pool.py", line 654, in terminate
    self._terminate()
  File "/usr/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/usr/lib/python3.9/multiprocessing/pool.py", line 729, in _terminate_pool
    p.join()
  File "/usr/lib/python3.9/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.9/multiprocessing/popen_fork.py", line 43, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
  File "/usr/lib/python3.9/multiprocessing/popen_fork.py", line 27, in poll
    pid, sts = os.waitpid(self.pid, flag)
  File "/home/you/.local/lib/python3.9/site-packages/rq/timeouts.py", line 63, in handle_death_penalty
    raise self._exception('Task exceeded maximum timeout value '
rq.timeouts.JobTimeoutException: Task exceeded maximum timeout value (180 seconds)
```
```
> python run_test_job.py
1
2
3
4
5
6
8
9
10
11
12
13
14
15
16
17
18
19
20
21
23
24
26
27
28
```
The console log above shows what run_test_job.py prints when the hang happens. It should reach 29 and 30 for the 29th and 30th job completions and then finish, but it never does because of the timeout.
How to reproduce
I've been able to reproduce this bug across two different Linux machines with different Linux distributions and Python versions.
System 1
OS: Red Hat Linux
Python: 3.6.13
Packages:
- torch==1.9.0
- transformers==4.11.3

System 2
OS: Manjaro Linux
Python: 3.9.9
Packages:
- torch==1.9.0
- transformers==4.11.3
I haven't tried newer versions of torch or transformers yet.
Reproduction requires triggering the deadlock in one of the workers. The deadlock seems to occur more often as the number of workers increases, and more often near the beginning of a worker's lifetime.
I've been able to reproduce it approximately 80% of the time with 4 workers on a 4-core machine immediately after starting the workers.
I've since also been able to reproduce this with the latest versions of torch and transformers: torch==1.10.1, transformers==4.15.0.
To confirm this is a problem with rq and not the BERT model itself, I ran 4 Python instances in parallel executing the code below for 10 minutes and could not reproduce the deadlock. Something introduced by rq is therefore the likely cause of this bug.
The longest iteration took 51 seconds, which is nowhere near rq's 180-second timeout.
```python
import torch
from transformers import BertForQuestionAnswering, BertTokenizer, pipeline
from datetime import datetime

def job_init():
    global nlp_pipe
    global model
    global tokenizer
    model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
    tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
    nlp_pipe = pipeline('question-answering', model=model, tokenizer=tokenizer)

def run_bert(question, text):
    print('first')
    output = nlp_pipe({
        'question': question,
        'context': text
    })
    print('second')
    output = nlp_pipe({
        'question': question,
        'context': text
    })
    print('third')
    output = nlp_pipe({
        'question': question,
        'context': text
    })
    print('fourth')
    output = nlp_pipe({
        'question': question,
        'context': text
    })
    print(output)
    return output

if __name__ == '__main__':
    job_init()
    while True:
        output = run_bert("Who?", "John eats a lot of pizza")
        now = datetime.now().time()
        print('[' + str(now) + '] Next iter')
```
The problem seems to be reduced with torch.set_num_threads(1), which suggests the workers could be contending over threads or shared memory somehow.
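For reference, a minimal sketch of that mitigation, with my own assumption that the call belongs in `job_init()` before the model is loaded (the original report doesn't say where it was placed):

```python
import torch

def job_init():
    # Mitigation mentioned above: restrict PyTorch's intra-op parallelism
    # to a single thread in each worker process. With several workers on
    # one machine this reportedly reduces the deadlock, and it also avoids
    # oversubscribing the CPU cores.
    torch.set_num_threads(1)
    # ...then build the model/tokenizer/pipeline exactly as in worker_test.py

job_init()
print(torch.get_num_threads())  # → 1
```

Note that this limits per-job throughput in exchange for stability, so it's a workaround rather than a fix.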
@xNul did you try to run this with the regular Worker instead of the SimpleWorker? I'm not familiar with torch, maybe each job is creating many threads? I'm not sure if the SimpleWorker is thread safe indeed. @selwin do you know?
@lowercase00 I did, but I believe it made no difference. My memory is hazy now that it's been over a year. I ended up not using Redis at all and going a different architectural route even though Redis was producing better results when it didn't deadlock. It should still be possible to reproduce with the given code and check that way though.
@xNul I was thinking of also using Redis to load multiple models but now that I read your issue, I was wondering if you could let me know which architecture you went for instead?
@PedroMTQ I just went the multiprocessing route and spun up multiple processes to do things in parallel instead. There's a way to share the model across all those processes without duplicating memory. I did that too.
Thanks @xNul that was my initial approach but I was hoping RQ would streamline it a bit more; so I guess I will go back to it.
This is a bit off topic but hopefully you have tried this:
I was hoping to still use RQ to serve as the API for the user to receive/send data, i.e., have the model(s) running with multiprocessing in the background and a single RQ worker just getting user requests. Having an RQ worker dealing with the queue priorities would be quite useful.
My plan was to have a master process which then spawns children processes (which are hosting identical models), but I'm struggling to figure out how to have queue.enqueue interact with this master process.
Any ideas?