logo

🌟 SaRA: эффективный файнтюн диффузионных моделей с помощью разреженной низкоранговой адаптации.

🌟 SaRA: эффективный файнтюн диффузионных моделей с помощью разреженной низкоранговой адаптации.
🌟 SaRA: эффективный файнтюн диффузионных моделей с помощью разреженной низкоранговой адаптации.

Sparse Low Rank Adaptation (SaRA) - метод дополнительного обучения для диффузионных моделей, который использует "неэффективные" параметры с наименьшими абсолютными значениями в предобученной модели.

SaRA позволяет улучшить генеративные способности модели, адаптируя ее к новым задачам, сохраняя при этом обобщающие способности исходной модели. SaRA отличается простотой реализации, требуя модификации всего одной строки кода в исходном скрипте обучения.

Идея метода о том, что параметры модели с наименьшими абсолютными значениями, хотя и не оказывают существенного влияния на инференс модели, обладают потенциалом для обучения новым знаниям. Потенциал обусловлен не структурными ограничениями модели, а скорее случайностью процесса оптимизации во время обучения.

Чтобы предотвратить переобучение, которое может возникнуть из-за сильной способности к представлению разреженных матриц, в SaRA используется функция потерь на основе ядерной нормы (nuclear norm-based) для ограничения ранга обучаемых матриц.

Для более плотного использования "неэффективных" параметров, используется прогрессивная стратегия настройки параметров процесса файнтюна - на более поздних этапах обучения происходит повторный выбор "неэффективных" параметров для повышения адаптивности модели.

Для решения проблемы высокого потребления VRAM, характерной для методов selective PEFT, SaRA использует алгоритм «неструктурного обратного распространения ошибки». Этот алгоритм хранит и обновляет градиенты только для обучаемых параметров, значительно сокращая использование памяти во время обучения.

Проведенные эксперименты на моделях Stable Diffusion (14, 1.5, 2.0, 3.0) демонстрируют эффективность SaRA в сравнении с другими методами файнтюна:

🟢LoRA: экономия 52% VRAM;

🟢LT-SFT: экономия 45% VRAM.

⚠️ Метод был успешно протестирован на venv : Python 3.9.5 и CUDA 11.8. Подробный туториал разработчик обещает выложить в репозиторий на Github до 30 сентября 2024 г.

В планах проекта - поддержка Dreambooth и Animatediff. Сроки по реализации планов не уточняются.

▶️Использование SaRA :

# easily employ SaRA to finetune the model by modifying a single line of code:
from optim import adamw
model = Initialize_model()
optimizer = adamw(model,threshold=2e-3) # <-modify this line only
for data in dataloader:
model.train()
model.save()

# Save and load only the trainable parameters
optimizer = adamw(model,threshold=2e-3)
optimizer.load($path_to_save)
torch.save(optimizer.save_params(),$path_to_save)


🟡Страница проекта
🟡Arxiv
🖥Github


@ai_machinelearning_big_data

#AI #ML #Finetuning #Diffusers #SaRA

Канал источник:@ai_machinelearning_big_data