Jak działa regresja liniowa? I czy warto ją stosować?

Kilka dni temu robiąc kawę nagle usłyszałem jakieś dziwne odgłosy dochodzące z pokoju Otysi. Zaniepokojony poszedłem zobaczyć, co się dzieje. Moja młodsza córeczka siedziała na środku swojego pokoju i plastikowym młotkiem uderzała w pudełko z zabawkami. Wszedłem do środka i zapytałem:

– Kochanie co robisz??

– Oti che ochochyc – powiedziała, co oznacza w jej języku „Oti chce otworzyć”.

– Skarbie nie musisz używać młotka, aby otworzyć pudełko. Jest dużo łatwiejszy i szybszy sposób – powiedziałem, po czym nacisnąłem mały przycisk otwierający pudełko.

Na twarzy Otysi natychmiast pojawił się uśmiech i już nigdy więcej nie próbowała otwierać pudełek przy pomocy młotka.

Regresja liniowa to jeden z najprostszych algorytmów uczenia nadzorowanego w uczeniu maszynowym. Jestem przekonany, że jeśli studiowałeś przedmioty techniczne, to na pewno się z nią zetknąłeś. Jest to jeden z najprostszych algorytmów stosowany powszechnie do prognozowania i w wielu miejscach daje wystarczające wyniki. Największą jej zaletą jest prosta konstrukcja, która zapewnia pełną interpretowalność modelu. Ale zacznijmy od początku…!

Regresja liniowa – ogólna idea

Zadaniem regresji liniowej jest po prostu dopasowanie prostej linii do danych. Warto podkreślić, że regresja liniowa przyjmuje założenie, że związek między cechami a zmienną objaśnianą jest mniej więcej liniowy.

Weźmy prosty przykład – inspiracja: Cassie Kozyrkov (tutaj link).

Mamy szklankę jogurtu. Dodajemy do niego sekretny składnik i miksujemy. Zadanie, które stoi przed nami, to oszacowanie ile kalorii będzie miał koktajl.

Niech naszą charakterystyką będzie waga składnika. Zobaczmy kilka przykładowych koktajli 🙂

Czyli nasze zadanie to tak naprawdę dopasowanie prostej linii do punktów, która będzie jak najlepiej charakteryzować zależność między wagą a ostatecznymi kaloriami.

  • y – zmienna objaśniana (ostateczna ilość kalorii),
  • x – zmienna objaśniająca (waga tajnego składnika),
  • a – współczynnik kierunkowy regresji (ang. slope),
  • b – wyraz wolny (ang. intercept).

Znajdźmy teraz najlepszą prostą! Ale jak? Po prostu wystarczy wyznaczyć linię, która jest jak najbliżej naszych punktów (wykorzystując wykres rozrzutu).

Klasyka, czyli metoda najmniejszych kwadratów

Jeszcze nigdy nie spotkałem się z sytuacją, aby w analizowanych przeze mnie danych istniała prosta przechodząca przez wszystkie punkty. Gdyby tak się stało, to szukałbym błędu w danych :).

Zatem należy wybrać metodę, która pozwala na znalezienie optymalnej prostej i która najlepiej pokaże zależność pomiędzy x i y. Tych metod jest naprawdę dużo, natomiast najpopularniejszą jest metoda najmniejszych kwadratów.

metoda najmniejszych kwadratów
W 1901 roku statystyk Karl Pearson używał „linii regresji” do określenia estymacji metodą najmniejszych kwadratów.

Metoda ta polega na tym, że należy znaleźć minimum dla sumy kwadratów różnic wartości obserwowanych i wartości z naszego równania.

Troszkę obracamy krzywą i ponownie wyliczamy te wartości. W ten sposób szukamy najlepszych parametrów.

W powyższym przypadku (1 zmienna) możemy użyć matematyki i wyprowadzić wzory na współczynnik a oraz b. Troszkę potu, łez i otrzymuje się wzory na współczynniki:

oraz:

Zatem dla naszego przypadku otrzymamy równanie:

A tutaj wizualizacja najlepiej dopasowanej krzywej do naszych danych!

UWAGA!

Powyższy przykład był bardzo prosty. Zazwyczaj mamy więcej danych i charakterystyk. Wówczas równanie przyjmuje bardziej złożoną postać:

Do znalezienia najlepszych parametrów wykorzystuje się wówczas algorytmy optymalizujące zaszyte w odpowiednich bibliotekach. Często pod spodem jest wykorzystywany gradient prosty.

Na czym polega gradient prosty?

Bardzo przemawia do mnie wyjaśnienie Aurelion’a Gerion’a z książki „Uczenie maszynowe z użyciem Scikit-Learn i TensorFlow”, więc po prostu je zacytuję:

„Załóżmy, że zabłądziliśmy w górach z powodu gęstej mgły; wyczuwamy jedynie nachylenie terenu pod stopami. Logicznym rozwiązaniem zejścia na dno doliny jest podążanie w dół po jak największej pochyłości. Dokładnie taki jest mechanizm działania gradientu prostego: algorytm mierzy lokalny gradient funkcji błędu w odniesieniu do wektora parametrów, a następnie podąża w kierunku malejącego gradientu.”

Interpretacja współczynnika kierunkowego i wyrazu wolnego

Super! Mamy nasze równanie. Teraz wystarczy tylko je zinterpretować i wiedzieć, co oznaczają parametry.

Wartość współczynnika kierunkowego mówi o wpływie zmiany x na zmienną y. Jeśli wartość a jest dodatnia, to wzrost x oznacza, że możemy się średnio spodziewać wzrostu y o a jednostek. Jeśli a ma wartość ujemną, to wraz ze wzrostem x o jednostkę, y średnio zmniejsza się o a jednostek. W naszym przypadku oznacza to, że każdy dodatkowy jeden gram to wzrost o 4.1 kalorii.

Wyraz wolny mówi jakiej wartości y powinniśmy się spodziewać dla zerowego x. W zależności od danego problemu może coś oznaczać lub nic. W naszym przykładzie można go zinterpretować jako bazę do koktajlu! I to ma sens. Natomiast w przypadku szacowania cen nieruchomości w oparciu o niektóre parametry może nic nie znaczyć. Po prostu jest za mało danych dla nieruchomości o małych metrażach. Zatem wyraz wolny w tym przypadku mógłby mówić, że nieruchomość z 0 metrów we Wrocławiu kosztuje 50.000 PLN.

Czy dobrze wybraliśmy linię, czyli magia reszt

Mamy już nasze równanie i wartości rzeczywiste. Możemy teraz wyliczyć, jaka jest różnica między tymi wartościami. Nazywamy ją resztą (ang. residual). Tak naprawdę w metodzie najmniejszych kwadratów właśnie te różnice były minimalizowane podczas liczenia parametru a oraz b.

Poniżej jest wykres reszt. Przedstawia on dla każdej wartości, jak bardzo nasze reszty różnią się od linii regresji. Najlepiej, gdy te reszty na całej długości linii są mniej więcej podobnie rozłożone. W przypadku, gdy wiemy, że nie mamy już więcej danych czy innych cech (charakterystyk), to regresja liniowa nie będzie optymalną metodą do wykorzystania.

Możemy też wyliczyć współczynnik determinacji (R²), który jest jedną z miar jakości dopasowania modelu do danych uczących. Jego opracowanie przypisuje się amerykańskiemu biologowi Sewall Wright w 1921 roku.

W naszym przypadku R² możemy interpretować, w jakim stopniu waga tajemniczego składnika wyjaśnia kaloryczność całego koktajlu.

Po przeliczeniu R² równa się 0.47, czyli inaczej 47%. Możemy zatem powiedzieć, że waga magicznego składnika „tłumaczy” 47% kalorii w naszym koktajlu.

Dla przypomnienia powiem tylko, że gdy R² = 1 wówczas nasz model jest idealny.

Regresja liniowa – wady i zalety

Zalety:
  • Prostota – dzięki swojej prostocie wykorzystywana jest w wielu dziedzinach: od matematyki, poprzez ekonomię, aż po geodezję,
  • interpretowalność – dzięki prostym wzorom bardzo łatwo wyjaśnić biznesowi lub innym osobom, jak dana cecha wpływa na wynik modelu,
  • szybkość – nawet przy dużej liczbie danych dla tych prostych algorytmów wyniki dostajemy prawie od razu.
Wady:
  • Prostota (wcześniej zaleta ;P) – świat nie składa się z prostych liniowych zależności. Gdyby tak było, to pewnie nie byłoby takiego rozwoju uczenia maszynowego.

Regresja liniowa jest dobrym wstępem, zanim zechcemy wdrażać na produkcji bardziej skomplikowane rozwiązania.

  • Wartości odstające bardzo zaburzają wyniki. Dlatego, zanim przystąpimy do pracy, warto się ich pozbyć.

Kod w Python – regresja liniowa

TUTAJ możesz pobrać pliczek z danymi, który przygotowałem na potrzeby tego artykułu.

No to import potrzebnych bibliotek:

import pandas as pd
import numpy as np

from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

Wczytajmy dane i losujemy 5 rekordów, by zobaczyć jak wygląda zbiór:

df = pd.read_csv('../data/kalorie.csv', sep=';')
df.sample(5)

Teraz obliczam podstawowy model regresji liniowej i wyliczam błąd, którym niech będzie RMSE:

model = LinearRegression()

X = df[['Masa']]
y = df['kalorie_koktajl']

model.fit(X,y, sample_weight=None)
y_pred = model.predict(X)

print('coefficients: ', model.coef_)
print('intercept: ', model.intercept_)
print('Root Mean Squared Error (RMSE): %.2f'% np.sqrt(mean_squared_error(y, y_pred)))
print('Coefficient of determination: %.2f'% r2_score(y, y_pred))

W tym przypadku RMSE wyniosło 93.15.

Dodajmy kolejną cechę – niech to będzie % białka w koktajlu. Przeliczmy jeszcze raz i…

X = df[['Masa','%bialka_w_koktajlu']]
y = df['kalorie_koktajl']

model.fit(X,y, sample_weight=None)
y_pred = model.predict(X)

print('coefficients: ', model.coef_)
print('intercept: ', model.intercept_)
print('Root Mean Squared Error (RMSE): %.2f'% np.sqrt(mean_squared_error(y, y_pred)))
print('Coefficient of determination: %.2f'% r2_score(y, y_pred))

Nic się nie poprawiło. Zatem ta cecha jest niepotrzebna i nie ma potrzeby komplikowania modelu.

Co jest ważniejsze: lepsze algorytmy czy wiedza domenowa?

Chciałem pokazać Ci potęgę wiedzy domenowej! Ludzie już zrozumieli czym są kalorie. Wystarczyłoby porozmawiać z osobą posiadającą wiedzę domenową w tym zakresie (np. z dietetykiem), aby nam powiedział, że magicznym składnikiem może być:

  • tłuszcz (g),
  • białgo (g),
  • węglowodany (g).

Dlaczego? Bo 1 gram tłuszczu to około 9 kalorii, a białko i węglowodany na około 4 kalorie. Bez odpowiedniej wiedzy domenowej można nie wiedzieć prostych rzeczy. Dlatego zawsze zachęcam do rozmawiania z ekspertami z dziedziny, której dotyczy model.

Uwzględniając te trzy cechy budujemy model:

X = df[['Bialka', 'Tluszcze', 'Weglowodany']]
y = df['kalorie_koktajl']

model.fit(X,y)
y_pred = model.predict(X)

print('coefficients: ', model.coef_)
print('intercept: ', model.intercept_)
print('Root Mean Squared Error (RMSE): %.2f'% np.sqrt(mean_squared_error(y, y_pred)))
print('Coefficient of determination: %.2f'% r2_score(y, y_pred))

Ostateczna postać naszego modelu to:

Błąd bardzo mocno spadł. Wynosi zaledwie 5.34. A równanie pokazuje całkiem blisko tego, co mówią dietetycy na temat odżywiania 🙂 Pewnie gdyby dodać więcej produktów, to lepiej model nauczyłby się wartości dla tłuszczu.

A czym jest nasz wyraz wolny? To jest baza naszego koktajlu, czyli 200g jogurtu naturalnego (w danych miał dokładnie 138 kalorii)!

Pewnie nie ma co szukać dokładniejszego modelu czy dodawać kolejne charakterystyki.

Podsumowanie

Nasz najlepszy model będzie taki, jakie dane do niego wrzucimy – pamiętaj o tym! Więc nawet czasami najprostsze algorytmy (np. regresja liniowa) z fajnymi i mądrymi cechami są w stanie dać bardzo zadowalający wynik.

Pozdrawiam,

podpis Mirek

.

2 Comments on “Jak działa regresja liniowa? I czy warto ją stosować?”

Dodaj komentarz

Twój adres email nie zostanie opublikowany. Pola, których wypełnienie jest wymagane, są oznaczone symbolem *