This is the 12th day of my participation in the Gwen Challenge.
1. Background
Generally speaking, salaries rise with years of service, and the rate of increase differs from industry to industry. This case study uses a linear regression model to explore how length of service affects salary, that is, to build a salary prediction model. It then compares the salary prediction models of several industries to analyze the characteristics of each industry.
2. Reading the data
First, take the IT industry as an example. The monthly salaries of 100 IT engineers in Beijing with 0 ~ 8 years of work experience were selected and stored in an Excel workbook named “IT Industry Income table.xlsx”. The data is read with the following code.
import pandas as pd
df = pd.read_excel('IT Industry Income table.xlsx')  # load the workbook into a DataFrame
df.head()  # display the first five rows
Running this code displays the first five rows of the table, as shown in the table below.
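The workbook itself is not included here. To follow along without it, a minimal sketch like the following builds a hypothetical DataFrame of the same shape (100 engineers, 0 ~ 8 years of service). The column names 'length of service' and 'salary' and the synthetic salary formula are assumptions for illustration, not the article's data, so any numbers it produces will differ from the article's results.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the workbook: 100 engineers with 0-8 years of
# service. Column names and the salary formula are illustrative assumptions.
rng = np.random.default_rng(seed=0)
years = rng.integers(0, 9, size=100)        # integers 0..8 inclusive
noise = rng.normal(0, 3000, size=100)       # random scatter around the trend
salary = 9000 + 2000 * years + noise        # assumed linear trend plus noise

df = pd.DataFrame({'length of service': years, 'salary': salary.round()})
print(df.head())  # first five rows, as in the article
```

Swapping this DataFrame for the one returned by pd.read_excel() is the only change needed to run the remaining steps on real data.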
Here, length of service is the independent variable and salary is the dependent variable. They are selected with the following code.
X = df[['length of service']]  # independent variable as a 2-D DataFrame (column names must match the workbook headers)
Y = df['salary']  # dependent variable as a 1-D Series
The independent variable X has to be two-dimensional, for reasons I mentioned earlier: scikit-learn expects the features in the shape (n_samples, n_features). The dependent variable Y can be one-dimensional, but writing it as a two-dimensional structure, df[['salary']], also works with the model built later.
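The shape rules can be checked directly. This sketch uses a tiny hypothetical DataFrame (column names assumed to match the workbook) to show the shapes produced by single versus double brackets, that scikit-learn rejects a one-dimensional X, and how the shape of Y changes the shape of the fitted coefficient.

```python
import pandas as pd
from sklearn.linear_model import LinearRegression

# Tiny hypothetical DataFrame; the column names are assumed to match the workbook.
df = pd.DataFrame({'length of service': [0, 2, 4, 6],
                   'salary': [9000, 13000, 17000, 21000]})

X = df[['length of service']]   # double brackets -> DataFrame, shape (4, 1)
Y = df['salary']                # single brackets -> Series, shape (4,)
print(X.shape, Y.shape)         # (4, 1) (4,)

# A 1-D X is rejected: scikit-learn wants (n_samples, n_features).
try:
    LinearRegression().fit(df['length of service'], Y)
    rejected_1d_x = False
except ValueError:
    rejected_1d_x = True
print('1-D X rejected:', rejected_1d_x)  # 1-D X rejected: True

# Both target shapes fit, but the learned attributes change shape.
regr_1d = LinearRegression().fit(X, Y)                 # coef_ shape (1,)
regr_2d = LinearRegression().fit(X, df[['salary']])    # coef_ shape (1, 1)
print(regr_1d.coef_.shape, regr_2d.coef_.shape)        # (1,) (1, 1)
```

With a two-dimensional Y, regr.coef_[0] is itself a one-element array, so printing it shows the number wrapped in brackets rather than as a bare value.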
The scatter plot can be drawn with the following code.
from matplotlib import pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render Chinese characters in labels
plt.scatter(X, Y)  # one point per engineer
plt.xlabel('length of service')
plt.ylabel('salary')
plt.show()
The second line sets the font to SimHei, a Chinese sans-serif ("Hei") typeface, so that Chinese labels display correctly; lines 4 and 5 add axis labels with plt.xlabel() and plt.ylabel(). The resulting plot is shown below.
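The font setting only matters when labels contain Chinese characters. A minimal sketch of a more defensive configuration follows, assuming SimHei may not be installed: it lists DejaVu Sans (bundled with matplotlib) as a fallback, and sets axes.unicode_minus to False, a companion setting often used with SimHei because that font is commonly reported to lack the Unicode minus sign used on negative tick labels. The labels and data are hypothetical, and the figure is saved to a file with the non-interactive Agg backend instead of being shown.

```python
import os
import tempfile

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: render to a file, no window
from matplotlib import pyplot as plt

# Prefer SimHei for Chinese text; fall back to DejaVu Sans (bundled with
# matplotlib) if SimHei is not installed on this machine.
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
# Draw minus signs as an ASCII hyphen-minus instead of U+2212.
plt.rcParams['axes.unicode_minus'] = False

# Hypothetical data and labels, only to exercise the settings.
years = [0, 2, 4, 6, 8]
salary = [9000, 13000, 17000, 21000, 25000]

fig, ax = plt.subplots()
ax.scatter(years, salary)
ax.set_xlabel('工龄 (length of service)')
ax.set_ylabel('薪水 (salary)')
out_path = os.path.join(tempfile.gettempdir(), 'scatter_sketch.png')
fig.savefig(out_path)
plt.close(fig)
```

If SimHei is missing, matplotlib logs a font warning and the Chinese characters may render as empty boxes, but the script still completes.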
3. Building the model
The following code is used to build the linear regression model.
from sklearn.linear_model import LinearRegression
regr = LinearRegression()  # ordinary least-squares linear regression
regr.fit(X, Y)  # estimate the slope and intercept from the data
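To see what fit() estimates, the sketch below computes the ordinary least-squares slope and intercept by hand with NumPy and compares them with scikit-learn's values. For a single feature the formulas are a = sum((x - x_mean)(y - y_mean)) / sum((x - x_mean)^2) and b = y_mean - a * x_mean. The data is hypothetical; with the article's DataFrame the two results should likewise agree up to floating-point error.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# Hypothetical data; column names assumed to match the workbook.
df = pd.DataFrame({'length of service': [0, 1, 3, 5, 8],
                   'salary': [9500, 11200, 15800, 19900, 26100]})
X = df[['length of service']]
Y = df['salary']

regr = LinearRegression()
regr.fit(X, Y)

# Ordinary least squares for one feature, computed by hand.
x = df['length of service'].to_numpy(dtype=float)
y = Y.to_numpy(dtype=float)
a = np.sum((x - x.mean()) * (y - y.mean())) / np.sum((x - x.mean()) ** 2)
b = y.mean() - a * x.mean()

print(a, regr.coef_[0])      # slope: hand-computed vs scikit-learn
print(b, regr.intercept_)    # intercept: hand-computed vs scikit-learn
```

The closed form applies only to a single feature; with several features scikit-learn solves the general least-squares problem instead.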
4. Model visualization
The linear regression model is visualized by the following code.
plt.scatter(X, Y)  # original data points
plt.plot(X, regr.predict(X), color='red')  # fitted trend line
plt.xlabel('length of service')
plt.ylabel('salary')
plt.show()
color='red' in line 2 draws the trend line in red. The resulting plot is shown below.
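Calling regr.predict(X) evaluates the model only at the observed x values. The sketch below, on hypothetical data, instead evaluates it on 50 evenly spaced points from the smallest to the largest observed length of service, so the red line is drawn as one clean segment. Wrapping the grid in a DataFrame with the training column name (assumed here to be 'length of service') avoids the "X does not have valid feature names" warning that recent scikit-learn versions emit when a model fitted on a DataFrame is given a plain array.

```python
import os
import tempfile

import matplotlib
matplotlib.use('Agg')  # headless backend for this sketch
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression

# Hypothetical data; column names assumed to match the workbook.
df = pd.DataFrame({'length of service': [0, 1, 3, 5, 8],
                   'salary': [9500, 11200, 15800, 19900, 26100]})
X = df[['length of service']]
Y = df['salary']
regr = LinearRegression().fit(X, Y)

# Sorted, evenly spaced inputs spanning the observed range, kept as a
# DataFrame with the same column name used during fitting.
x_col = df['length of service']
grid = pd.DataFrame({'length of service': np.linspace(x_col.min(), x_col.max(), 50)})

fig, ax = plt.subplots()
ax.scatter(x_col, Y)                                                 # data points
ax.plot(grid['length of service'], regr.predict(grid), color='red')  # trend line
ax.set_xlabel('length of service')
ax.set_ylabel('salary')
out_path = os.path.join(tempfile.gettempdir(), 'trend_line_sketch.png')
fig.savefig(out_path)
plt.close(fig)
```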
5. Constructing the linear regression equation
The following code prints the coefficient a and the intercept b of the trend line.
print('coefficient a: ' + str(regr.coef_[0]))  # slope of the fitted line
print('intercept b: ' + str(regr.intercept_))  # predicted salary when x = 0
The result is shown below. The simple (one-variable) linear regression equation obtained by fitting is therefore y = 2497x + 10143.
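As a worked example of using the fitted equation, the sketch below plugs a few values of x into y = 2497x + 10143, hard-coding the coefficients reported above. predict_salary is a hypothetical helper name. Because the data covers only 0 ~ 8 years of service, inputs outside that range would be extrapolations.

```python
def predict_salary(years_of_service: float) -> float:
    """Evaluate the fitted equation y = 2497x + 10143.

    Results are in the same units as the salary column of the workbook.
    """
    a = 2497    # fitted coefficient (slope) reported in the article
    b = 10143   # fitted intercept reported in the article
    return a * years_of_service + b

for years in (0, 5, 8):
    print(years, predict_salary(years))
# 0 10143
# 5 22628
# 8 30119
```

With the trained model, regr.predict(pd.DataFrame({'length of service': [5]})) should return approximately the same value, assuming the independent-variable column is named 'length of service'; any difference comes from rounding in the printed coefficients.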