Не хочется грузить библиотеку torch - это как? Вы же её используете по полной? Через что вы вычисляете значения модели? Ведь строка "with torch.no_grad()" неспроста у вас.
По какой причине решено экономить на спичках? Время работы, экономия памяти?
Попробуй вместо "проблемного" "torch.argmax(model_out[0], dim=-1)" поставить 0 и проанализируй как поменялось использование ресурсов.