Fails when using multi-threading and CUDA device. SOLVED
#3 by CoderCowMoo - opened
Quan mate, I've spent about 10-20 hours looking through the modelling file, the PyTorch issues and code, and the Transformers documentation and code, trying to figure out why the example code in the README.md doesn't work in a Gradio demo.
Turns out, Gradio uses separate threads to execute functions tied to inputs and outputs.
It also turns out that `torch.set_default_device` doesn't work across threads in PyTorch <= 2.2.2, so tensors created inside those Gradio worker threads end up on CPU instead of the CUDA device.
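Here's a minimal sketch of what I think is going on, with no Gradio involved at all, just a plain Python thread standing in for Gradio's worker threads:

```python
import threading
import torch

torch.set_default_device("cuda")

# In the main thread the default device applies as expected.
print(torch.empty(1).device)  # cuda:0

def worker():
    # In PyTorch <= 2.2.2 the default set above doesn't carry over into
    # this thread, so the tensor silently lands on CPU and later clashes
    # with the CUDA model.
    print(torch.empty(1).device)  # cpu

t = threading.Thread(target=worker)
t.start()
t.join()
```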
The solution is one line:
`torch.set_default_tensor_type('torch.cuda.FloatTensor')`
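For reference, here's a rough sketch of how I'm wiring the workaround into a Gradio app. The model id, generation settings, and `trust_remote_code=True` (since this model ships its own modelling file) are placeholders/assumptions, not the exact README code:

```python
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Process-wide default, so tensors created in Gradio's worker threads
# also land on CUDA (unlike torch.set_default_device on <= 2.2.2).
torch.set_default_tensor_type('torch.cuda.FloatTensor')

model_id = "org/model"  # placeholder, swap in the actual repo id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, trust_remote_code=True
).cuda()

def respond(prompt):
    # This runs in a Gradio worker thread, not the main thread.
    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

demo = gr.Interface(fn=respond, inputs="text", outputs="text")
demo.launch()
```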
Gradio demo (hopefully with streaming) coming soon.