from sklearn.linear_model import LinearRegression, Ridge, Lasso

# Class to hold the best linear model
class class_best_regressor:
    """Keep track of the best-scoring regressor among the candidates checked.

    Attributes:
        score: Highest value returned so far by a candidate's ``score``
            method; negative infinity until a candidate has been checked.
        model: The candidate that produced ``score``; ``None`` until a
            candidate has been checked.
    """

    def __init__(self):
        # Start below every possible score so the first candidate is always
        # recorded. Starting at 0 would silently discard candidates whose
        # score is negative and leave an unevaluated default in place.
        self.score = float("-inf")
        # No model is fitted here: fitting belongs to the caller, not to
        # the tracker's definition.
        self.model = None

    def check_regressor(self, regressor, X=None, y=None):
        """Score ``regressor`` and keep it if it beats the current best.

        Args:
            regressor: Fitted estimator exposing ``score(X, y)``.
            X: Features to score on. Defaults to the module-level ``X_test``.
            y: Targets to score on. Defaults to the module-level
                ``y_test_log``.
        """
        # Resolve the defaults at call time so existing one-argument calls
        # keep using the globals exactly as before.
        if X is None:
            X = X_test
        if y is None:
            y = y_test_log

        new_score = regressor.score(X, y)
        # Strict comparison: on a tie the earlier candidate is kept.
        if new_score > self.score:
            self.score = new_score
            self.model = regressor

# Tracker that retains whichever candidate scores best.
best_regressor = class_best_regressor()

# Baseline: ordinary least squares, no regularisation.
linreg = LinearRegression()
linreg.fit(X_train, y_train)
best_regressor.check_regressor(linreg)

# Lasso: one fit per regularisation strength, each offered to the tracker.
for alpha in (0.00001, 0.001, 0.005, 0.05, 0.5, 1):
    linlasso = Lasso(alpha=alpha, max_iter=100_000)
    linlasso.fit(X_train, y_train)
    best_regressor.check_regressor(linlasso)

# Ridge: same grid of strengths as Lasso.
for alpha in (0.00001, 0.001, 0.005, 0.05, 0.5, 1):
    linridge = Ridge(alpha=alpha, max_iter=10000)
    linridge.fit(X_train, y_train)
    best_regressor.check_regressor(linridge)

# Report which candidate the tracker ended up holding.
print(f'Best regressor: {best_regressor.model}')

# Plot best model

# 2x2 grid of panels, filled one at a time by the plotting code that follows.
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# The Figure-level title method is `suptitle`; `Figure` has no `subtitle`.
fig.suptitle('Linear Regression (Simple, Lasso, Ridge)')
# Flatten the grid into an iterator so each panel can claim the next axes
# with `next(axs)`.
axs = iter(axs.ravel())

# Plotting of the train dataset
ax = next(axs)
# Predictions keyed by the training index so they can be matched to y_train.
y_hat_train = pd.Series(best_regressor.model.predict(X_train), index=y_train.index).sort_values()
# np.expm1 is applied to both axes before plotting.
# NOTE(review): presumably the targets were log1p-transformed upstream - confirm.
sns.scatterplot(
    x=np.expm1(y_hat_train),
    # Align the real values to the sorted predictions explicitly, as the
    # test-dataset panel does, instead of relying on implicit index alignment.
    y=np.expm1(y_train.loc[y_hat_train.index]),
    marker='o',
    s=50,
    alpha=0.8,
    ax=ax,
)
ax.set(
    ylabel='Real value',
    xlabel='Predicted value',
    # A single-backslash "\n" yields a real line break in the title.
    title=f'Linear Regression\nTrain Dataset - R²={best_regressor.model.score(X_train, y_train):.3f}',
)

# Plotting of the test dataset
ax = next(axs)
# Predictions keyed by the test index so they can be matched to y_test_log.
y_hat = pd.Series(best_regressor.model.predict(X_test), index=X_test.index).sort_values()
# np.expm1 is applied to both axes before plotting.
# NOTE(review): presumably the targets were log1p-transformed upstream - confirm.
sns.scatterplot(
    x=np.expm1(y_hat),
    # Reorder the real values to follow the sorted predictions.
    y=np.expm1(y_test_log.loc[y_hat.index]),
    marker='o',
    s=50,
    alpha=0.8,
    ax=ax,
)
ax.set(
    ylabel='Real value',
    xlabel='Predicted value',
    # A single-backslash "\n" yields a real line break in the title.
    title=f'Linear Regression\nTest Dataset - R²={best_regressor.model.score(X_test, y_test_log):.3f}',
)

# Residuals scatterplot: residual against predicted value, test set only.

ax = next(axs)

# Order the test targets by value and put the feature rows in the same order.
sorted_y_test = y_test_log.sort_values()
sorted_x_test = X_test.reindex(sorted_y_test.index)

# Both the predictions and the targets go through np.expm1 before subtracting.
predicted_values = np.expm1(best_regressor.model.predict(sorted_x_test))
residuals = np.expm1(sorted_y_test) - predicted_values

sns.scatterplot(x=predicted_values, y=residuals, ax=ax)
ax.set_ylabel('Residual')
ax.set_xlabel('Predicted value')
ax.set_title('Residuals');

# Plotting of residuals Histogram

ax = next(axs)
# Distribution of the residuals, with a kernel density estimate overlaid.
sns.histplot(data=residuals, ax=ax, kde=True)
# The x axis carries residual values (not predictions); the y axis counts them.
# The trailing semicolon is kept from the original.
# NOTE(review): presumably it suppresses the notebook echo - confirm.
ax.set(
    ylabel='Residual count',
    xlabel='Residual',
    title='Residuals histogram',
);