class: ur-title, center, middle, title-slide

.title[
# BST430 Lecture 13
]
.subtitle[
## Multivariate linear models
]
.author[
### Seong-Hwan Jun, based on the course by Andrew McDavid and Tanzy Love
]
.institute[
### U of Rochester
]
.date[
### 2021-11-07 (updated: 2025-10-16)
]

---

# Agenda

0. Multivariate plots
1. More modeling syntax
2. Diagnostics
3. Interpreting interaction models
4. Many models

Here's the [R code in this lecture](l12/multivariate-lm.R)

---
class: code70

### `mtcars`

- For this lecture, we will use the (infamous) `mtcars` dataset that comes with R by default.

```r
library(tidyverse)
library(tidymodels)
library(broom)
data("mtcars")
glimpse(mtcars)
```

```
## Rows: 32
## Columns: 11
## $ mpg  <dbl> 21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22…
## $ cyl  <dbl> 6, 6, 4, 6, 8, 6, 8, 4, 4, 6, 6, 8, 8, 8, 8, 8, 8,…
## $ disp <dbl> 160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 1…
## $ hp   <dbl> 110, 110, 93, 110, 175, 105, 245, 62, 95, 123, 123…
## $ drat <dbl> 3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.…
## $ wt   <dbl> 2.620, 2.875, 2.320, 3.215, 3.440, 3.460, 3.570, 3…
## $ qsec <dbl> 16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 2…
## $ vs   <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,…
## $ am   <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ gear <dbl> 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3,…
## $ carb <dbl> 4, 4, 1, 1, 2, 1, 4, 2, 2, 4, 4, 3, 3, 3, 4, 4, 4,…
```

---
class: middle

.hand[Visualizing multivariate relationships]

---

### Exploratory data analyses of multivariate data

- Hard
- Necessary
- Often generates quite a few plots before you identify one that you can keep

---

### One simple trick

Pairs plots!
.panelset[
.panel[.panel-name[Code]

```r
library(GGally)
GGally::ggpairs(mtcars)
```
]
.panel[.panel-name[Plot]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-4-1.png" width="70%" style="display: block; margin: auto;" />
]
]

---

* Evidently `vs`, `am`, `gear` and maybe `cyl` should be cast to factors

.panelset[
.panel[.panel-name[Code]

```r
mtcars = mtcars %>%
  mutate(across(c(vs, am, gear, cyl), factor))
GGally::ggpairs(mtcars)
```
]
.panel[.panel-name[Plot]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-5-1.png" width="70%" style="display: block; margin: auto;" />
]
]

---

- Let's suppose we wanted to determine which variables affect fuel efficiency (the `mpg` variable in the dataset).
- `cyl`, `disp`, `hp`, `drat`, `wt` all appear to be correlated with `mpg`
- `vs`, `am` and possibly `gear` could have distributional differences, as well

---
class: code70

- To begin, we'll look at the association between the variables log-weight and mpg

```r
ggplot(mtcars, aes(x = wt, y = mpg)) +
  geom_point() +
  scale_x_continuous(transform = "log") +
  xlab("Weight") +
  ylab("Miles Per Gallon")
```

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-6-1.png" width="60%" style="display: block; margin: auto;" />

What do you notice?

---

## Quiz

- Draw a line through the points using `geom_smooth`. What do you think is the sign of the slope of the line?
- Would you say log-weight is positively or negatively associated with mpg?

---

### OLS fits in R

1. Make sure you have the explanatory variables in the format you want:

    ```r
    submt = mtcars %>%
      mutate(logwt = log(wt))
    ```

2. Use `lm()`

    ```r
    lmout = lm(mpg ~ logwt, data = submt)
    lmtide = tidy(lmout)
    select(lmtide, term, estimate)
    ```

    ```
    ## # A tibble: 2 × 2
    ##   term        estimate
    ##   <chr>          <dbl>
    ## 1 (Intercept)     39.3
    ## 2 logwt          -17.1
    ```

---

### `lm` syntax

```r
lm(response ~ pred1 + pred2*pred3, data = data)
```

Finds the OLS estimates of the following model:

> response = `\(\beta_0+ \beta_1\text{pred1} + \beta_2\text{pred2} + \beta_{3}\text{pred3}+ \beta_{23}\text{pred2} * \text{pred3}\)` + error

The `data` argument tells `lm()` where to find the response and explanatory variables.

---

### Formula syntax

* `x:z` forms the interaction between `x` and `z` -- this is element-by-element multiplication if at least one of `x` or `z` is continuous, otherwise it is an outer product.
* `x*z` forms interactions and includes main effects. Equivalent to `x + z + x:z`
* `+ 0` or `- 1` excludes the intercept term, and `- x` excludes `x` if it would otherwise be included.
* `(x + z + u)^2` forms all two-way interactions and main effects with `x`, `z`, `u`.
* `I(x - 10)` or `I(x^2)` -- perform arithmetic operations that use `+`, `-`, `*`, `^`, etc., on a variable "on-the-fly" in the formula.
    * better than just transforming before you model
* `y ~ .` uses every variable in the data, except the response `y`.

---
class: code50

### Quiz

.alert[Don't do this without thinking!]

```r
tidy(lm(mpg ~ ., data = mtcars))
```

```
## # A tibble: 13 × 5
##   term        estimate std.error statistic p.value
##   <chr>          <dbl>     <dbl>     <dbl>   <dbl>
## 1 (Intercept)  15.1      17.1        0.881   0.389
## 2 cyl6         -1.20      2.39      -0.502   0.621
## 3 cyl8          3.05      4.83       0.633   0.535
## 4 disp          0.0126    0.0177     0.708   0.487
## # ℹ 9 more rows
```

This regresses mpg on all of the variables.

- Examine the coefficients (estimates). Is there any value that you find counterintuitive?
- What do the standard errors, statistic, and p-value represent?

Run `?mtcars` to find the details for each variable.

---

### How does `lm` work?

How do we estimate `\(\hat{\beta}\)`?
`\begin{align}
y_i = \beta^{\top} x_i + \epsilon_i, \quad i = 1, ..., n
\end{align}`

- `\(y_i\)` is the response for observation `\(i\)`.
- `\(x_i\)` is a vector of explanatory variables for `\(i\)`.
- `\(\epsilon_i\)` represents unexplained variation (error).

---

### How does `lm` work?

We minimize the sum of squared errors:

`\begin{align}
\hat{\beta} = \arg\min_{\beta} \sum_{i=1}^{n} (y_i - \beta^{\top} x_i)^2.
\end{align}`

- `\(e_i = (y_i - \beta^{\top} x_i)\)` denotes the error for observation `\(i\)`.

Q: How do we minimize this?

---

### `\(R^2\)`: R squared

`\(R^2\)` is a measure of how well the model fits the data.

Let `\(\bar{y} = \frac{1}{n} \sum_{i=1}^{n} y_i\)`.

`\begin{align}
\text{SST} &= \sum_i (y_i - \bar{y})^2 \\
\text{SSE} &= \sum_i (y_i - \hat{\beta}^{\top} x_i)^2 \\
R^2 &= 1 - \frac{\text{SSE}}{\text{SST}}
\end{align}`

- Q: Does a higher value of `\(R^2\)` indicate a better or worse fit?
- Answer it in terms of the sum of squares of regression.

---

- Use the `broom::glance()` function to get the estimated residual standard deviation, `\(R^2\)`, and information criteria. The estimated standard deviation is the value in the `sigma` column.

```r
glance(lmout)
```

```
## # A tibble: 1 × 12
##   r.squared adj.r.squared sigma statistic  p.value    df logLik
##       <dbl>         <dbl> <dbl>     <dbl>    <dbl> <dbl>  <dbl>
## 1     0.810         0.804  2.67      128. 2.39e-12     1  -75.8
## # ℹ 5 more variables: AIC <dbl>, BIC <dbl>, deviance <dbl>,
## #   df.residual <int>, nobs <int>
```

---

### Quiz

- Choose a subset of the variables to include in the regression model for `mpg` (around 2 or 3 variables).
- Justify why you chose these variables. Generate a figure to support your choice.
- Fit the model and compare the `\(R^2\)` value to the model fitted using all of the variables.
- Now fit the model using only the log weight variable. Compare the `\(R^2\)` values.
- Comment on your findings.

---

# Prediction (Interpolation)

- **Interpolation**: Making estimates/predictions within the range of the data.
- **Extrapolation**: Making estimates/predictions outside the range of the data.
- Interpolation is fine. Extrapolation is dangerous.

.question[why?]

---

- Interpolation

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-11-1.png" width="60%" style="display: block; margin: auto;" />

---

- Extrapolation

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-12-1.png" width="60%" style="display: block; margin: auto;" />

---

## Why is extrapolation dangerous?

1. Not sure if the linear relationship is the same outside the range of the data (because we don't have data there to see the relationship).
2. Not sure if the variability is the same outside the range of the data (because we don't have data there to see the variability).

---

### Make a prediction:

1. You need a data frame with the exact same variable name as the explanatory variable.

    ```r
    newdf = tribble(~logwt,
                    1,
                    1.5)
    ```

2. Then you use the `predict()` function to obtain predictions.

    ```r
    newdf = newdf %>%
      mutate(predictions = predict(object = lmout, newdata = newdf))
    ```

---

## Assumptions and Violations

- In the linear model, you can trade assumptions for inference:
- Assumptions in *decreasing* order of importance
    1. **Independence** - The knowledge of the value of one observation does not give you any information on the value of another.
    2. **Linearity** - The relationship looks like a straight line.
    3. **Equal Variance** - The spread is the same for every value of `\(x\)`
    4. **Normality** - The distribution of the errors isn't too skewed and there aren't any *too* extreme points. (Thanks to the [central limit theorem](https://en.wikipedia.org/wiki/Central_limit_theorem), this is only an issue if you have outliers and a small number of observations.)
    5. **Fixed and Known** predictors

---

### What do we lose when violated?

1. **Linearity** violated - Linear regression line does not pick up actual conditional expectation. As a linear approximation, the results will be sensitive to the particular `\(x\)` sampled.
2. **Independence** violated - Linear regression line is unbiased, but standard errors can be badly off.
3. **Equal Variance** violated - Linear regression line is unbiased, but standard errors are off. Your `\(p\)`-values may be too small, or too large.
4. **Normality** violated - Only an issue if your sample size is "small". Unstable results if outliers are present. Your `\(p\)`-values may be too small, or too large.
5. **Fixed and Known** violated if there is measurement error in the predictors

Q: Do we need (distributional) assumptions on the explanatory variables (the `\(x_i\)`'s)?

---

## Evaluating Independence

- Think about the problem.
    - Were different responses measured on the same observational/experimental unit?
    - Were data collected in groups?
- Non-independence: The temperature today and the temperature tomorrow. If it is warm today, it is probably warm tomorrow.
- Non-independence: You are measuring properties of 500 single cells isolated from 3 mice. Because the cells within a given mouse are probably similar, each cell is not independent.

---

### xkcd 2533

<img src="l12/img/slope_hypothesis_testing_2x.png" width="60%" style="display: block; margin: auto;" />

Via [xkcd](https://xkcd.com/2533/)

---
class: code50

### Evaluating other assumptions via residual diagnostics

- Obtain the residuals by using `augment()` from broom. They will be the `.resid` variable.
```r
aout = augment(lmout)
glimpse(aout)
```

```
## Rows: 32
## Columns: 9
## $ .rownames  <chr> "Mazda RX4", "Mazda RX4 Wag", "Datsun 710", …
## $ mpg        <dbl> 21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24…
## $ logwt      <dbl> 0.9631743, 1.0560527, 0.8415672, 1.1678274, …
## $ .fitted    <dbl> 22.79984, 21.21293, 24.87761, 19.30316, 18.1…
## $ .resid     <dbl> -1.79984305, -0.21293388, -2.07760886, 2.096…
## $ .hat       <dbl> 0.03929578, 0.03263072, 0.05636910, 0.031929…
## $ .sigma     <dbl> 2.693624, 2.714823, 2.685917, 2.686126, 2.71…
## $ .cooksd    <dbl> 9.677220e-03, 1.109293e-04, 1.917252e-02, 1.…
## $ .std.resid <dbl> -0.68787926, -0.08110004, -0.80118930, 0.798…
```

---

### Quiz

Generate plots:

- residuals `\(r_i\)` vs fitted values `\(\hat{y}_i\)`
- residuals `\(r_i\)` vs response `\(y_i\)`
- residuals `\(r_i\)` vs one continuous and one categorical explanatory variable

---

### Potential remedies

1. Linearity Violated: Try a transformation. If the relationship looks curved and monotone (i.e. either always increasing or always decreasing) then try a log transformation.
2. Independence Violated: Try a two-step procedure (that collapses over repeated measures) or use longitudinal data analysis.
3. Equal Variance Violated: If the relationship is also curved and monotone, try a log transformation on the response variable. Use a bootstrap. Or stay tuned until you learn about sandwich estimation.
4. Normality Violated: Don't trust prediction intervals (confidence intervals are fine).
5. Fixed and Known Violated: Fit measurement error models.

---

### Example 1: A perfect residual plot

.panelset[
.panel[.panel-name[Fit]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-17-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residual]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-18-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Verify

- ✅ Means are straight lines
- ✅ Residuals seem to be centered at 0 for all `\(x\)`
- ✅ Variance looks equal for all `\(x\)`
- ✅ Everything looks perfect (too perfect...🤔)

---

### Example 2: Curved Monotone Relationship, Equal Variances

- Simulate data:

```r
set.seed(1)
x = rexp(100)
x = x - min(x) + 0.5
y = log(x) * 20 + rnorm(100, sd = 4)
(df_fake = tibble(x, y))
```

```
## # A tibble: 100 × 2
##        x       y
##    <dbl>   <dbl>
## 1  1.22    0.419
## 2  1.64   12.3
## 3  0.608  -9.46
## 4  0.603 -11.3
...
```

---

.panelset[
.panel[.panel-name[Fit]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-20-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residual]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-21-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Check

- ❌ Curved (but always increasing) relationship between `\(x\)` and `\(y\)`.
- ✅ Variance looks equal for all `\(x\)`
- ❌ Residual plot has a parabolic shape.
- **Solution**: These indicate a `\(\log\)` transformation of `\(x\)` could help.

```r
df_fake %>%
  mutate(logx = log(x)) ->
  df_fake
lm_fake = lm(y ~ logx, data = df_fake)
```

---

## Quiz

Plot the fitted values vs residuals from `lm_fake` after the log transformation of `x`.
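---

One possible sketch for the quiz above, assuming the simulated data and the `lm_fake` fit from Example 2. The data are regenerated here so the chunk runs on its own, and `aug_fake` is a name introduced only for this illustration. If the log transformation helped, the residuals should scatter evenly around zero with no parabolic trend.

```r
library(tidyverse)
library(broom)

# Re-create the Example 2 data and the log-x fit
set.seed(1)
x = rexp(100)
x = x - min(x) + 0.5
y = log(x) * 20 + rnorm(100, sd = 4)
df_fake = tibble(x, y) %>%
  mutate(logx = log(x))
lm_fake = lm(y ~ logx, data = df_fake)

# augment() adds .fitted and .resid columns to the model data
aug_fake = augment(lm_fake)

ggplot(aug_fake, aes(x = .fitted, y = .resid)) +
  geom_point() +
  geom_hline(yintercept = 0, linetype = "dashed") +  # reference line at zero
  xlab("Fitted values") +
  ylab("Residuals")
```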
---

### Example 3: Curved Non-monotone Relationship, Equal Variances

- Simulate data:

```r
set.seed(1)
x = rnorm(100)
y = -x^2 + rnorm(100)
df_fake = tibble(x, y)
```

---

.panelset[
.panel[.panel-name[Fit]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-24-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residual]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-25-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Verify

- ❌ Curved relationship between `\(x\)` and `\(y\)`
- ❌ Sometimes the relationship is increasing, sometimes it is decreasing.
- ✅ Variance looks equal for all `\(x\)`
- ❌ Residual plot has a parabolic form.
- **Solution**: Include a squared term in the model (or use a gam or spline)

```r
lmout = lm(y ~ I(x^2), data = df_fake)
```

---

### Example 4: Curved Relationship, Variance Increases with `\(Y\)`

- Simulate data:

```r
set.seed(1)
x = rnorm(100)
y = exp(x + rnorm(100, sd = 1/2))
df_fake = tibble(x, y)
```

---

.panelset[
.panel[.panel-name[Fit]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-28-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residual]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-29-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Verify

- ❌ Curved relationship between `\(x\)` and `\(y\)`
- ❌ Variance looks like it increases as `\(y\)` increases
- ❌ Residual plot has a parabolic form.
- ❌ Residual plot variance looks larger to the right and smaller to the left.
- **Solution**: Take a log-transformation of `\(y\)`.
```r
df_fake %>%
  mutate(logy = log(y)) ->
  df_fake
lm_fake = lm(logy ~ x, data = df_fake)
```

---

### Example 5: Linear Relationship, Equal Variances, Skewed Distribution

- Simulate data:

```r
set.seed(1)
x = runif(200)
y = 15 * x + rexp(200, 0.2)
```

---

.panelset[
.panel[.panel-name[Fit]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-32-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residual]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-33-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Verify

- ✅ Straight line relationship between `\(x\)` and `\(y\)`.
- ✅ Variances about equal for all `\(x\)`
- ❌ Skew for all `\(x\)`
- ❌ Residual plots show skew.
- **Solution**: Do nothing, but report skew, or use a bootstrap or robust standard errors

---

### Example 6: Linear Relationship, Unequal Variances

- Simulate data:

```r
set.seed(1)
x = runif(100) * 10
y = 0.85 * x + rnorm(100, sd = (x - 5) ^ 2)
df_fake = tibble(x, y)
```

---

.panelset[
.panel[.panel-name[Fit]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-35-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residuals]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-36-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Verify

- ✅ Linear relationship between `\(x\)` and `\(y\)`.
- ❌ Variance is different for different values of `\(x\)`.
- **Solution**: The modern solution is to bootstrap or to use sandwich estimates of the standard errors

```r
rob_fit = estimatr::lm_robust(y ~ x, data = df_fake)
tidy(rob_fit)
```

```
##          term  estimate std.error statistic     p.value
## 1 (Intercept) -2.861944 2.8353903 -1.009365 0.315285178
## 2           x  1.368002 0.5167311  2.647416 0.009452483
##     conf.low conf.high df outcome
## 1 -8.4886842  2.764795 98       y
## 2  0.3425659  2.393438 98       y
```

---

# Summary of R commands

- `augment()`:
    - Residuals `\(r_i = y_i - \hat{y}_i\)`: `$.resid`
    - Fitted Values `\(\hat{y}_i\)`: `$.fitted`
- `tidy()`:
    - Name of variables: `$term`
    - Coefficient Estimates: `$estimate`
    - Standard Error (standard deviation of sampling distribution of coefficient estimates): `$std.error`
    - t-statistic: `$statistic`
    - p-value: `$p.value`
- `glance()`:
    - R-squared value (proportion of variance explained by regression line, higher is better): `$r.squared`
    - AIC (lower is better): `$AIC`
    - BIC (lower is better): `$BIC`

---
class: middle

# Interpreting coefficients when you log

---

### Log `x`

Generally, when you use logs, you interpret associations on a *multiplicative* scale instead of an *additive* scale.

No log:

- Model: `\(E[y_i] = \beta_0 + \beta_1 x_i\)`
- Observations that differ by 1 unit in `\(x\)` tend to differ by `\(\beta_1\)` units in `\(y\)`.

Log `\(x\)`:

- Model: `\(E[y_i] = \beta_0 + \beta_1 \log_2(x_i)\)`
- Observations that are twice as large in `\(x\)` tend to differ by `\(\beta_1\)` units in `\(y\)`.

---

### Log `y`

Log `\(y\)`:

- Model: `\(E[\log_2(y_i)] = \beta_0 + \beta_1 x_i\)`
- Observations that differ by 1 unit in `\(x\)` tend to be `\(2^{\beta_1}\)` times larger in `\(y\)`.

Log both:

- Model: `\(E[\log_2(y_i)] = \beta_0 + \beta_1 \log_2(x_i)\)`
- Observations that are twice as large in `\(x\)` tend to be `\(2^{\beta_1}\)` times larger in `\(y\)`.
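
As a hedged worked example, assuming the `mtcars` data from earlier: the `mpg ~ logwt` fit used the natural log rather than `\(\log_2\)`, so the change in mpg associated with doubling weight is `\(\beta_1 \log(2)\)` rather than `\(\beta_1\)`. Refitting with `log2(wt)` should give the same number directly as the slope. The names `fit_ln`, `fit_log2`, `per_doubling_ln` and `per_doubling_log2` are introduced only for this sketch.

```r
data("mtcars")

# Natural-log fit, equivalent to the earlier `mpg ~ logwt` model
fit_ln = lm(mpg ~ log(wt), data = mtcars)
# Base-2 fit: the slope is the change in mpg per doubling of weight
fit_log2 = lm(mpg ~ log2(wt), data = mtcars)

# Rescale the natural-log slope to a per-doubling effect
per_doubling_ln = coef(fit_ln)[["log(wt)"]] * log(2)
per_doubling_log2 = coef(fit_log2)[["log2(wt)"]]

c(per_doubling_ln, per_doubling_log2)
```

With the slope of about -17.1 reported earlier, both values should come out to roughly -12 mpg per doubling of weight.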
.footnote[Note: we commit statistical abuse here, since `\(\exp \left[ \text{E}(\log(Y) | X) \right] \neq \text{E}(Y | X)\)`, i.e., `exp` doesn't commute with the expectation. The delta method, though, says this is the first-order approximation.
]

---
class: middle

.hand[Interpreting interaction models]

---

## When in doubt, predict it out!

With interaction models, it's easy to trick yourself. But you can also make predictions to check your understanding.

```r
only_wt = lm(mpg ~ log(wt), mtcars)
only_disp = lm(mpg ~ log(disp), mtcars)
only_cyl = lm(mpg ~ cyl, mtcars)
complicated_model = lm(mpg ~ (log(wt) + log(disp))*cyl, mtcars)
```

---

```r
GGally::ggcoef_compare(list(only_wt = only_wt,
                            only_disp = only_disp,
                            only_cyl = only_cyl,
                            complicated = complicated_model))
```

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-39-1.png" width="80%" style="display: block; margin: auto;" />

---

### Quiz

Why are `log(disp):cyl6` and `log(disp):cyl8` positive?

- Check coefficients for `log(disp)`, `cyl6`, and `cyl8`.

---

### Predict it out!

.panelset[
.panel[.panel-name[Code]

```r
(newdf = expand.grid(wt = mean(mtcars$wt),
                     disp = mean(mtcars$disp),
                     cyl = factor(c(4, 6, 8))))
```

```
##        wt     disp cyl
## 1 3.21725 230.7219   4
## 2 3.21725 230.7219   6
## 3 3.21725 230.7219   8
```

```r
# Pass the grid explicitly as newdata; across() with no columns is deprecated in recent dplyr
newdf = newdf %>%
  mutate(mpg = predict(complicated_model, newdata = newdf))

ggplot(mtcars, aes(x = cyl, y = mpg)) +
  geom_boxplot() +
  geom_line(data = newdf, aes(group = 1), color = 'red')
```
]
.panel[.panel-name[Plot]

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-40-1.png" width="70%" style="display: block; margin: auto;" />
]
]

---

### We are extrapolating!

```r
ggplot(mtcars, aes(x = log(disp), y = mpg)) +
  geom_point() +
  facet_wrap(~cyl) +
  geom_smooth(method = 'lm') +
  geom_point(data = newdf, aes(group = 1), color = 'red')
```

<img src="l13-multivariate-lm_files/figure-html/unnamed-chunk-41-1.png" width="60%" style="display: block; margin: auto;" />

---

# Acknowledgments

Adapted from David Gerard's [Stat 512](https://data-science-master.github.io/lectures/05_linear_regression/05_simple_linear_regression.html)