Tehnicile de învățare profundă de ultimă generație se bazează pe modele supra-parametrizate greu de implementat. Dimpotrivă, se știe că rețelele neuronale biologice utilizează o conectivitate redusă eficientă. Identificarea tehnicilor optime pentru comprimarea modelelor prin reducerea numărului de parametri din ele este importantă pentru a reduce consumul de memorie, baterie și hardware fără a sacrifica precizia, a implementa modele ușoare pe dispozitiv și pentru a garanta confidențialitatea cu calculul privat pe dispozitiv. Pe frontul cercetării, tăierea este utilizată pentru a investiga diferențele din dinamica învățării dintre rețelele supra-parametrizate și sub-parametrizate, pentru a studia rolul subrețelelor și inițializărilor rare („bilete de loterie”) ca o tehnică de căutare a arhitecturii neuronale distructive și Marea.

tăiere

În acest tutorial, veți învăța cum să utilizați torch.nn.utils.prune pentru a sparsifica rețelele neuronale și cum să îl extindeți pentru a implementa propria tehnică de tăiere personalizată.

Cerințe¶

Creați un model

În acest tutorial, folosim arhitectura LeNet de la LeCun și colab., 1998.

Inspectați un modul¶

Să inspectăm stratul de conv1 (netuns) din modelul nostru LeNet. Deocamdată va conține greutatea și părtinirea a doi parametri și nu există tampoane.

Tunderea unui modul¶

Pentru a tăia un modul (în acest exemplu, stratul conv1 al arhitecturii noastre LeNet), selectați mai întâi o tehnică de tăiere dintre cele disponibile în torch.nn.utils.prune (sau implementați-vă propriul prin subclasarea BasePruningMethod). Apoi, specificați modulul și numele parametrului de tăiat în cadrul acelui modul. În cele din urmă, utilizând argumentele adecvate ale cuvintelor cheie solicitate de tehnica de tăiere selectată, specificați parametrii de tăiere.

În acest exemplu, vom tăia la întâmplare 30% din conexiunile din parametrul numit greutate în stratul conv1. Modulul este transmis ca primul argument către funcție; numele identifică parametrul din cadrul acelui modul utilizând identificatorul său de șir; și suma indică fie procentul conexiunilor la prune (dacă este un float între 0. și 1.), fie numărul absolut de conexiuni la prune (dacă este un număr întreg negativ).

Tunderea acționează prin eliminarea greutății din parametri și înlocuirea acesteia cu un nou parametru numit weight_orig (adică adăugarea „_orig” la numele parametrului inițial). weight_orig stochează versiunea netezată a tensorului. Tendința nu a fost tăiată, așa că va rămâne intactă.

Masca de tăiere generată de tehnica de tăiere selectată mai sus este salvată ca tampon de modul numit weight_mask (adică adăugând „_mask” la numele parametrului inițial).

Pentru ca trecerea înainte să funcționeze fără modificări, atributul de greutate trebuie să existe. Tehnicile de tăiere implementate în torch.nn.utils.prune calculează versiunea tăiată a greutății (prin combinarea măștii cu parametrul original) și le stochează în greutatea atributului. Rețineți, acesta nu mai este un parametru al modulului, acum este pur și simplu un atribut.

În cele din urmă, tăierea se aplică înainte de fiecare trecere înainte folosind forward_pre_hooks ale PyTorch. Mai exact, atunci când modulul este tăiat, așa cum am făcut aici, acesta va achiziționa un forward_pre_hook pentru fiecare parametru asociat acestuia care este tăiat. În acest caz, deoarece până acum am tăiat doar parametrul original numit greutate, va fi prezent un singur cârlig.

Pentru completare, putem acum să tundem și părtinirea, pentru a vedea cum se modifică parametrii, tampoanele, cârligele și atributele modulului. Doar pentru a încerca o altă tehnică de tăiere, aici tăiem cele mai mici 3 intrări în prejudecată prin norma L1, așa cum este implementat în funcția de tăiere l1_unstructured.

Acum ne așteptăm ca parametrii numiți să includă atât weight_orig (dinainte), cât și bias_orig. Tampoanele vor include weight_mask și bias_mask. Versiunile tăiate ale celor două tensoare vor exista ca atribute ale modulului, iar modulul va avea acum două forward_pre_hooks .

Tundere iterativă¶

Același parametru dintr-un modul poate fi tăiat de mai multe ori, efectul diferitelor apeluri de tăiere fiind egal cu combinația diferitelor măști aplicate în serie. Combinarea unei noi măști cu vechea mască este gestionată de metoda compute_mask a PruningContainer.

Spuneți, de exemplu, că acum dorim să mai curățăm modulul greutate, de data aceasta folosind tăierea structurată de-a lungul axei a 0 a tensorului (axa a 0 corespunde canalelor de ieșire ale stratului convoluțional și are dimensionalitate 6 pentru conv1), pe baza pe norma L2 a canalelor. Acest lucru poate fi realizat folosind funcția ln_structured, cu n = 2 și dim = 0 .

Cârligul corespunzător va fi acum de tip torch.nn.utils.prune.PruningContainer și va stoca istoricul tăierii aplicat parametrului de greutate.

Serializarea unui model tăiat¶

Toți tensorii relevanți, inclusiv tampoanele de mască și parametrii originali utilizați pentru calcularea tensoarelor tăiate sunt stocate în starea_dict a modelului și, prin urmare, pot fi ușor serializate și salvate, dacă este necesar.

Eliminați parametrizarea din nou a tăierii¶

Pentru a face tăierea permanentă, eliminați re-parametrizarea în termeni de weight_orig și weight_mask și eliminați forward_pre_hook, putem utiliza funcționalitatea de eliminare din torch.nn.utils.prune. Rețineți că acest lucru nu anulează tăierea, ca și cum nu s-ar fi întâmplat niciodată. Pur și simplu îl face permanent, reasignând greutatea parametrilor la parametrii modelului, în versiunea sa tăiată.

Înainte de a elimina re-parametrizarea:

După eliminarea re-parametrizării:

Tunderea mai multor parametri într-un model¶

Prin specificarea tehnicii și parametrilor de tăiere dorită, putem tăia cu ușurință mai mulți tensori într-o rețea, poate în funcție de tipul lor, așa cum vom vedea în acest exemplu.

Tunderea globală

Până acum, ne-am uitat doar la ceea ce se numește de obicei tăiere „locală”, adică practica tăierii tensoarelor într-un model unul câte unul, prin compararea statisticilor (mărimea greutății, activare, gradient etc.) ale fiecărei intrări exclusiv cu celelalte intrări din acel tensor. Cu toate acestea, o tehnică obișnuită și poate mai puternică este de a tăia modelul dintr-o dată, eliminând (de exemplu) cel mai mic 20% din conexiunile din întregul model, în loc să eliminați cel mai mic 20% din conexiunile din fiecare strat. Acest lucru va duce probabil la procente diferite de tăiere pe strat. Să vedem cum să facem acest lucru folosind global_unstructured from torch.nn.utils.prune .

Acum putem verifica raritatea indusă în fiecare parametru tăiat, care nu va fi egal cu 20% în fiecare strat. Cu toate acestea, raritatea globală va fi (aproximativ) de 20%.

Extinderea torch.nn.utils.prune cu funcții de tăiere personalizate¶

Pentru a implementa propria funcție de tăiere, puteți extinde modulul nn.utils.prune subclasând clasa de bază BasePruningMethod, la fel ca toate celelalte metode de tăiere. Clasa de bază implementează următoarele metode pentru dvs.: __cell__, apply_mask, apply, tund și eliminați. Dincolo de unele cazuri speciale, nu ar trebui să reimplementați aceste metode pentru noua dvs. tehnică de tăiere. Cu toate acestea, va trebui să implementați __init__ (constructorul) și compute_mask (instrucțiunile despre cum să calculați masca tensorului dat conform logicii tehnicii dvs. de tăiere). În plus, va trebui să specificați ce tip de tăiere implementează această tehnică (opțiunile acceptate sunt globale, structurate și nestructurate). Acest lucru este necesar pentru a determina cum să combinați măștile în cazul în care tăierea este aplicată iterativ. Cu alte cuvinte, când se tunde un parametru pre-tăiat, se așteaptă ca tehnica curentă de tăiere să acționeze asupra porțiunii neprăpate a parametrului. Specificarea PRUNING_TYPE va permite PruningContainer (care se ocupă de aplicația iterativă a măștilor de tăiere) să identifice corect felia de parametru de tăiat.

Să presupunem, de exemplu, că doriți să implementați o tehnică de tăiere care să tundă toate celelalte intrări dintr-un tensor (sau - dacă tensorul a fost tăiat anterior - în porțiunea rămasă a tunsorului). Acest lucru va avea PRUNING_TYPE = „nestructurat” deoarece acționează asupra conexiunilor individuale dintr-un strat și nu asupra unităților/canalelor întregi („structurat”) sau a diferiților parametri („global”).