일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
- 자연어분석
- 파이토치기본
- 딥러닝
- NLP
- HTML
- deeplearning
- python 정렬
- Python
- 판다스
- sklearn
- 파이토치
- chatGPT
- fastapi #python웹개발
- 비지도학습
- 파이썬웹개발
- OpenAIAPI
- pytorch
- konlpy
- MachineLearning
- 파이썬
- 사이킷런
- fastapi
- 챗gpt
- programmablesearchengine
- langchain
- 판다스 데이터정렬
- fastapi #파이썬웹개발
- 랭체인
- pandas
- 머신러닝
- Today
- Total
Data Navigator
[sklearn] Logistic Regression을 활용한 소비자 광고 반응률 예측 본문
[sklearn] Logistic Regression을 활용한 소비자 광고 반응률 예측
코딩하고분석하는돌스 2021. 1. 22. 23:10
Logistic Regression을 활용한 소비자 광고 반응률 예측¶
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
df = pd.read_csv('./03. 광고 반응률 예측 (Logistic Regression)/advertising.csv')
df.head(10)
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 68.95 | NaN | 61833.90 | 256.09 | Cloned 5thgeneration orchestration | Wrightburgh | 0 | Tunisia | 3/27/2016 0:53 | 0 |
1 | 80.23 | 31.0 | 68441.85 | 193.77 | Monitored national standardization | West Jodi | 1 | Nauru | 4/4/2016 1:39 | 0 |
2 | 69.47 | 26.0 | 59785.94 | 236.50 | Organic bottom-line service-desk | Davidton | 0 | San Marino | 3/13/2016 20:35 | 0 |
3 | 74.15 | 29.0 | 54806.18 | 245.89 | Triple-buffered reciprocal time-frame | West Terrifurt | 1 | Italy | 1/10/2016 2:31 | 0 |
4 | 68.37 | 35.0 | 73889.99 | 225.58 | Robust logistical utilization | South Manuel | 0 | Iceland | 6/3/2016 3:36 | 0 |
5 | 59.99 | 23.0 | 59761.56 | 226.74 | Sharable client-driven software | Jamieberg | 1 | Norway | 5/19/2016 14:30 | 0 |
6 | 88.91 | NaN | 53852.85 | 208.36 | Enhanced dedicated support | Brandonstad | 0 | Myanmar | 1/28/2016 20:59 | 0 |
7 | 66.00 | 48.0 | 24593.33 | 131.76 | Reactive local challenge | Port Jefferybury | 1 | Australia | 3/7/2016 1:40 | 1 |
8 | 74.53 | 30.0 | 68862.00 | 221.51 | Configurable coherent function | West Colin | 1 | Grenada | 4/18/2016 9:33 | 0 |
9 | 69.88 | 20.0 | 55642.32 | 183.82 | Mandatory homogeneous architecture | Ramirezton | 1 | Ghana | 7/11/2016 1:42 | 0 |
개인의 수입을 직접적으로 알기 어렵기 때문에 AreaIncome에 지역 평균 임금을 사용해 대략적으로 개인의 임금을 추정함
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 10 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Daily Time Spent on Site 1000 non-null float64
1 Age 916 non-null float64
2 Area Income 1000 non-null float64
3 Daily Internet Usage 1000 non-null float64
4 Ad Topic Line 1000 non-null object
5 City 1000 non-null object
6 Male 1000 non-null int64
7 Country 1000 non-null object
8 Timestamp 1000 non-null object
9 Clicked on Ad 1000 non-null int64
dtypes: float64(4), int64(2), object(4)
memory usage: 78.2+ KB
특이한 값이 있는지 확인하기 위해서 describe() 로 확인
df.describe()
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Male | Clicked on Ad | |
---|---|---|---|---|---|---|
count | 1000.000000 | 916.000000 | 1000.000000 | 1000.000000 | 1000.000000 | 1000.00000 |
mean | 65.000200 | 36.128821 | 55000.000080 | 180.000100 | 0.481000 | 0.50000 |
std | 15.853615 | 9.018548 | 13414.634022 | 43.902339 | 0.499889 | 0.50025 |
min | 32.600000 | 19.000000 | 13996.500000 | 104.780000 | 0.000000 | 0.00000 |
25% | 51.360000 | 29.000000 | 47031.802500 | 138.830000 | 0.000000 | 0.00000 |
50% | 68.215000 | 35.000000 | 57012.300000 | 183.130000 | 0.000000 | 0.50000 |
75% | 78.547500 | 42.000000 | 65470.635000 | 218.792500 | 1.000000 | 1.00000 |
max | 91.430000 | 61.000000 | 79484.800000 | 269.960000 | 1.000000 | 1.00000 |
Area Income에서 min과 max 값의 차이가 좀 크다는 것을 알 수 있음
df['Area Income']
0 61833.90
1 68441.85
2 59785.94
3 54806.18
4 73889.99
...
995 71384.57
996 67782.17
997 42415.72
998 41920.79
999 29875.80
Name: Area Income, Length: 1000, dtype: float64
sns.distplot(df['Area Income'])
d:\ProgramData\Anaconda3\envs\bigdata\lib\site-packages\seaborn\distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
warnings.warn(msg, FutureWarning)
<AxesSubplot:xlabel='Area Income', ylabel='Density'>
sns.distplot(df['Age'])
d:\ProgramData\Anaconda3\envs\bigdata\lib\site-packages\seaborn\distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
warnings.warn(msg, FutureWarning)
<AxesSubplot:xlabel='Age', ylabel='Density'>
df['Country'].nunique()
237
df['City'].nunique()
969
df['Ad Topic Line'].nunique()
1000
missing value (결측치) 처리하기 (age 컬럼)
df.isna()
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
0 | False | True | False | False | False | False | False | False | False | False |
1 | False | False | False | False | False | False | False | False | False | False |
2 | False | False | False | False | False | False | False | False | False | False |
3 | False | False | False | False | False | False | False | False | False | False |
4 | False | False | False | False | False | False | False | False | False | False |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | False | False | False | False | False | False | False | False | False | False |
996 | False | False | False | False | False | False | False | False | False | False |
997 | False | False | False | False | False | False | False | False | False | False |
998 | False | False | False | False | False | False | False | False | False | False |
999 | False | False | False | False | False | False | False | False | False | False |
1000 rows × 10 columns
파이썬의 기본 함수인 sum()을 이용해서 Na값이 몇개인지 찾음
df.isna().sum()
Daily Time Spent on Site 0
Age 84
Area Income 0
Daily Internet Usage 0
Ad Topic Line 0
City 0
Male 0
Country 0
Timestamp 0
Clicked on Ad 0
dtype: int64
df.isna().sum() / len(df)
Daily Time Spent on Site 0.000
Age 0.084
Area Income 0.000
Daily Internet Usage 0.000
Ad Topic Line 0.000
City 0.000
Male 0.000
Country 0.000
Timestamp 0.000
Clicked on Ad 0.000
dtype: float64
inpute => missing value를 처리하는 방법¶
1. 삭제¶
삭제하기 dropna(inplace=True), df2 = dropna()
df.dropna()
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
1 | 80.23 | 31.0 | 68441.85 | 193.77 | Monitored national standardization | West Jodi | 1 | Nauru | 4/4/2016 1:39 | 0 |
2 | 69.47 | 26.0 | 59785.94 | 236.50 | Organic bottom-line service-desk | Davidton | 0 | San Marino | 3/13/2016 20:35 | 0 |
3 | 74.15 | 29.0 | 54806.18 | 245.89 | Triple-buffered reciprocal time-frame | West Terrifurt | 1 | Italy | 1/10/2016 2:31 | 0 |
4 | 68.37 | 35.0 | 73889.99 | 225.58 | Robust logistical utilization | South Manuel | 0 | Iceland | 6/3/2016 3:36 | 0 |
5 | 59.99 | 23.0 | 59761.56 | 226.74 | Sharable client-driven software | Jamieberg | 1 | Norway | 5/19/2016 14:30 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | 72.97 | 30.0 | 71384.57 | 208.58 | Fundamental modular algorithm | Duffystad | 1 | Lebanon | 2/11/2016 21:49 | 1 |
996 | 51.30 | 45.0 | 67782.17 | 134.42 | Grass-roots cohesive monitoring | New Darlene | 1 | Bosnia and Herzegovina | 4/22/2016 2:07 | 1 |
997 | 51.63 | 51.0 | 42415.72 | 120.37 | Expanded intangible solution | South Jessica | 1 | Mongolia | 2/1/2016 17:24 | 1 |
998 | 55.55 | 19.0 | 41920.79 | 187.95 | Proactive bandwidth-monitored policy | West Steven | 0 | Guatemala | 3/24/2016 2:35 | 0 |
999 | 45.01 | 26.0 | 29875.80 | 178.35 | Virtual 5thgeneration emulation | Ronniemouth | 0 | Brazil | 6/3/2016 21:43 | 1 |
916 rows × 10 columns
df.drop('Age', axis=1) 로 Age 컬럼을 삭제 => 하지만 중요한 데이터일 수도 있기 때문에 삭제하는 방법은 좋지 않음
2. 다른 값으로 처리 평균(mean) 값이나 중간값(median)으로 처리하는 것이 효율이 좋음¶
데이터 값이 비교적 고르게 분포할 때는 평균이 좋음
round(df['Age'].mean())
36
극단 값이 많을 때는 중간값이 좋다.
df['Age'].median()
35.0
sns.distplot(df['Age'])
d:\ProgramData\Anaconda3\envs\bigdata\lib\site-packages\seaborn\distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
warnings.warn(msg, FutureWarning)
<AxesSubplot:xlabel='Age', ylabel='Density'>
df = df.fillna(round(df['Age'].mean()))
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 10 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Daily Time Spent on Site 1000 non-null float64
1 Age 1000 non-null float64
2 Area Income 1000 non-null float64
3 Daily Internet Usage 1000 non-null float64
4 Ad Topic Line 1000 non-null object
5 City 1000 non-null object
6 Male 1000 non-null int64
7 Country 1000 non-null object
8 Timestamp 1000 non-null object
9 Clicked on Ad 1000 non-null int64
dtypes: float64(4), int64(2), object(4)
memory usage: 78.2+ KB
df.isna().sum()
Daily Time Spent on Site 0
Age 0
Area Income 0
Daily Internet Usage 0
Ad Topic Line 0
City 0
Male 0
Country 0
Timestamp 0
Clicked on Ad 0
dtype: int64
missing value가 너무 많다면 impute 하지 말고 삭제할 수도 있다. 하지만 고의적으로 정보를 숨기기 위해서 값을 안준 것일 수도 있기 때문에 주의
from sklearn.model_selection import train_test_split
X = df[['Daily Time Spent on Site','Age', 'Area Income', 'Daily Internet Usage','Male']] #독립변수
y = df['Clicked on Ad'] #종속변수
X
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Male | |
---|---|---|---|---|---|
0 | 68.95 | 36.0 | 61833.90 | 256.09 | 0 |
1 | 80.23 | 31.0 | 68441.85 | 193.77 | 1 |
2 | 69.47 | 26.0 | 59785.94 | 236.50 | 0 |
3 | 74.15 | 29.0 | 54806.18 | 245.89 | 1 |
4 | 68.37 | 35.0 | 73889.99 | 225.58 | 0 |
... | ... | ... | ... | ... | ... |
995 | 72.97 | 30.0 | 71384.57 | 208.58 | 1 |
996 | 51.30 | 45.0 | 67782.17 | 134.42 | 1 |
997 | 51.63 | 51.0 | 42415.72 | 120.37 | 1 |
998 | 55.55 | 19.0 | 41920.79 | 187.95 | 0 |
999 | 45.01 | 26.0 | 29875.80 | 178.35 | 0 |
1000 rows × 5 columns
y
0 0
1 0
2 0
3 0
4 0
..
995 1
996 1
997 1
998 0
999 1
Name: Clicked on Ad, Length: 1000, dtype: int64
데이터가 1000단위면 트레이닝 데이터와 테스트 데이터 비율을 8:2, 9:1 정도로 잡는 것이 좋다.¶
random_state = 100 을 해주는 이유는 랜덤으로 뽑는 값을 같게 하기 위해서이다.¶
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state = 100)
X_train
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Male | |
---|---|---|---|---|---|
675 | 82.58 | 38.0 | 65496.78 | 225.23 | 1 |
358 | 51.38 | 59.0 | 42362.49 | 158.56 | 0 |
159 | 75.55 | 36.0 | 73234.87 | 159.24 | 0 |
533 | 91.43 | 36.0 | 46964.11 | 209.91 | 1 |
678 | 87.85 | 34.0 | 51816.27 | 153.01 | 0 |
... | ... | ... | ... | ... | ... |
855 | 50.87 | 24.0 | 62939.50 | 190.41 | 0 |
871 | 76.79 | 27.0 | 55677.12 | 235.94 | 0 |
835 | 63.11 | 34.0 | 63107.88 | 254.94 | 1 |
792 | 56.56 | 26.0 | 68783.45 | 204.47 | 1 |
520 | 46.61 | 42.0 | 65856.74 | 136.18 | 0 |
800 rows × 5 columns
X_test
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Male | |
---|---|---|---|---|---|
249 | 62.20 | 25.0 | 25408.21 | 161.16 | 0 |
353 | 79.54 | 44.0 | 70492.60 | 217.68 | 1 |
537 | 61.72 | 26.0 | 67279.06 | 218.49 | 0 |
424 | 43.59 | 36.0 | 58849.77 | 132.31 | 1 |
564 | 64.75 | 36.0 | 63001.03 | 117.66 | 0 |
... | ... | ... | ... | ... | ... |
684 | 42.06 | 34.0 | 43241.19 | 131.55 | 0 |
644 | 78.35 | 46.0 | 53185.34 | 253.48 | 0 |
110 | 66.63 | 60.0 | 60333.38 | 176.98 | 0 |
28 | 70.20 | 34.0 | 32708.94 | 119.20 | 0 |
804 | 53.92 | 41.0 | 25739.09 | 125.46 | 1 |
200 rows × 5 columns
사이킷런의 Logistic Regression¶
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X_train, y_train)
LogisticRegression()
coef는 절대값이 클 수록 설명력이 좋다. 단, 변수간 단위가 일정할 때¶
model.coef_
array([[-6.64737762e-02, 2.66015818e-01, -1.15501902e-05,
-2.44285539e-02, 2.00758165e-03]])
pred = model.predict(X_test)
y_test
249 1
353 0
537 0
424 1
564 1
..
684 1
644 0
110 1
28 1
804 1
Name: Clicked on Ad, Length: 200, dtype: int64
from sklearn.metrics import accuracy_score, confusion_matrix
accuracy_score(y_test, pred)
0.9
confusion_matrix(y_test, pred)
array([[92, 8],
[12, 88]], dtype=int64)
파이썬 팁¶
컬럼에 같은 값이 몇 개 있는지 볼때 쓰는 함수¶
df['Country'].nunique()
237
value_counts()를 이용해서 같은 값이 몇 개 있는지 셀 수 있음.¶
df['Country'].value_counts()
Czech Republic 9
France 9
Senegal 8
Cyprus 8
Australia 8
..
Kiribati 1
Slovenia 1
Aruba 1
Jordan 1
Bermuda 1
Name: Country, Length: 237, dtype: int64
Logistic Regression의 원리¶
참 / 거짓, 선택 / 비선택 같은 두가지로 나누어지는 경우 Linear Regression으로 분석하면 1 이상 혹은 0 미만으로 0~1 범위를 넘어가는 경우가 생긴다. 그 때 0~1을 넘어가는 값은 의미가 없으므로 로짓변환을 통해 S 커브의 모양으로 그래프를 바꿔서 그려 0이나 1에 가까워지는 값을 갖도록 한다.
Binary Classification 풀어야 할 문제의 유형에 대한 문제¶
Confusion matrix와 Type-1, Type-2 Error¶
Type-1, Type-2 Error¶
Type-2 error 가 중요한 경우¶
위의 두 에러중에서 Type-2 error은 심각한 위험을 초래하기 때문에 예측 모델을 만들 때 Type-2 error 발생이 적게 나오도록 조정해야 한다.¶
Type-1 error 가 중요한 경우¶
'Machine Learning, Deep Learning' 카테고리의 다른 글
[sklearn] KNN(K Nearlist Neighbors) 알고리즘을 이용하여 고객이탈 예측하기 (0) | 2021.01.27 |
---|---|
[sklearn, NLP] 상품 리뷰 분석 NLP, Count Vectorizer, Naive Bayes Classifier (0) | 2021.01.24 |
[sklearn,statsmodels] Linear Regression을 이용한 고객별 연간 지출액 예측 (0) | 2021.01.22 |
[sklearn, SVM] 서포트 백터 머신을 이용해서 외국어 문장 판별하기 (0) | 2021.01.17 |
1. 머신러닝(Machine Learning) - 1일차 머신러닝 개념 및 종류 (0) | 2021.01.04 |