TensorFlow.js and linear regression

TensorFlow.js is a JavaScript machine learning library and part of the larger TensorFlow ecosystem for building ML-powered applications. It provides tools for training and deploying machine learning and deep learning models, and it runs both in web browsers and in Node.js.

Simple linear regression is a fundamental data analysis technique that models the relationship between two variables, an input x and an output y, as a straight line: y = w·x + b. It is an instance of supervised learning and one of the simplest tasks that can be performed with TensorFlow.js.
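With a single input, the least-squares line even has a closed-form solution: slope = Σ(x − x̄)(y − ȳ) / Σ(x − x̄)² and intercept = ȳ − slope·x̄. The plain-JavaScript sketch below computes it, which can be a handy reference for checking what a trained model should converge to. `fitLine` is a hypothetical helper written for illustration; it is not part of TensorFlow.js.

```javascript
// Closed-form least-squares fit of y ≈ slope * x + intercept.
// fitLine is a hypothetical helper used only for illustration.
function fitLine(xs, ys) {
  if (xs.length !== ys.length || xs.length < 2) {
    throw new Error('need at least two paired observations')
  }
  const n = xs.length
  const meanX = xs.reduce((a, b) => a + b, 0) / n
  const meanY = ys.reduce((a, b) => a + b, 0) / n
  let sxy = 0 // sum of cross-deviations
  let sxx = 0 // sum of squared x deviations
  for (let i = 0; i < n; i++) {
    sxy += (xs[i] - meanX) * (ys[i] - meanY)
    sxx += (xs[i] - meanX) ** 2
  }
  if (sxx === 0) throw new Error('x values must not all be equal')
  const slope = sxy / sxx
  const intercept = meanY - slope * meanX
  return { slope, intercept }
}

// Usage: points lying exactly on y = 2x + 1
const { slope, intercept } = fitLine([0, 1, 2, 3], [1, 3, 5, 7])
console.log(slope, intercept) // 2 1
```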

A typical ML workflow includes the following steps:

  • Load and prepare the input data.
  • Define the model architecture.
  • Compile and train the model.
  • Use the model to make predictions.

In a TensorFlow.js-based application, those steps correspond to the following sections of the code:

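const { inputs, labels } = await getData()   // 1. load and prepare the input data
const model = createModel()                   // 2. define the model architecture
model.compile(...)                            // 3. compile...
await model.fit(inputs, labels, ...)          //    ...and train the model
model.predict(...)                            // 4. use the model to make predictions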
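The compile and fit calls hide the training loop. As a rough, library-free illustration (not how TensorFlow.js is implemented internally), compiling a one-feature linear model with a mean-squared-error loss and a gradient-descent optimizer, for example `loss: 'meanSquaredError'` and `optimizer: tf.train.sgd(0.1)` in `model.compile`, and then fitting it amounts to repeatedly nudging w and b against the gradient of the loss. The sketch below uses full-batch gradient descent, whereas `model.fit` works on mini-batches by default; `trainLinear` and its options are hypothetical names.

```javascript
// Plain-JavaScript sketch of compile + fit + predict for y = w * x + b:
// mean squared error loss, minimized by full-batch gradient descent.
// trainLinear is a hypothetical helper name.
function trainLinear(xs, ys, { learningRate = 0.1, epochs = 2000 } = {}) {
  let w = 0
  let b = 0
  const n = xs.length
  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradW = 0
    let gradB = 0
    for (let i = 0; i < n; i++) {
      const error = w * xs[i] + b - ys[i] // prediction minus label
      gradW += (2 / n) * error * xs[i] // d(MSE)/dw
      gradB += (2 / n) * error // d(MSE)/db
    }
    w -= learningRate * gradW
    b -= learningRate * gradB
  }
  return { w, b, predict: x => w * x + b }
}

// Usage: noise-free samples of y = 2x + 1 on [0, 1]
const fitted = trainLinear([0, 0.25, 0.5, 0.75, 1], [1, 1.5, 2, 2.5, 3])
console.log(fitted.w.toFixed(3), fitted.b.toFixed(3)) // 2.000 1.000
console.log(fitted.predict(0.6).toFixed(3)) // 2.200
```

Inputs in a narrow range or on very different scales can make this kind of loop converge slowly, which is one reason data preparation often includes normalization.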

In the case of simple linear regression, the model architecture can be defined as follows:

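import * as tf from '@tensorflow/tfjs'
const model = tf.sequential()
// One input feature, one output unit: this layer computes w * x + b
model.add(tf.layers.dense({ inputShape: [1], units: 1, useBias: true }))
// With no activation, stacking a second dense layer keeps the model linear
model.add(tf.layers.dense({ units: 1, useBias: true }))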
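Because neither dense layer has an activation function, the stacked model still computes a straight line: composing y1 = w1·x + b1 with y2 = w2·y1 + b2 gives y2 = (w2·w1)·x + (w2·b1 + b2). The second layer therefore adds parameters but no expressive power, and a single dense layer with one unit already represents the full simple linear regression model. The plain-JavaScript check below illustrates the algebra with arbitrary weights; `denseUnit` is a hypothetical helper that mimics the arithmetic, not a TensorFlow.js function.

```javascript
// A dense layer with one input, one unit, and no activation computes w * x + b.
// denseUnit is a hypothetical helper that mimics that arithmetic in plain JS.
const denseUnit = (w, b) => x => w * x + b

// Two stacked linear units (arbitrary illustrative weights)
const layer1 = denseUnit(3, -1)
const layer2 = denseUnit(0.5, 2)
const stacked = x => layer2(layer1(x))

// The stack equals one line with slope w2 * w1 and intercept w2 * b1 + b2
const collapsed = denseUnit(0.5 * 3, 0.5 * -1 + 2)

// Each line prints the same number twice
for (const x of [-2, 0, 1.5, 4]) {
  console.log(stacked(x), collapsed(x))
}
```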

Example: Predicting wine alcohol content from density

The Wine Quality Data Set is a popular dataset describing red and white variants of the Portuguese vinho verde wine. For each wine sample, it records physicochemical features such as acidity, density, and alcohol content, along with a quality score.
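To use density as the input and alcohol content as the label, the two columns first have to be extracted from the data file. The sketch below assumes a semicolon-separated CSV whose header row contains (possibly quoted) column names `density` and `alcohol`, which is how the UCI copies of the files are commonly laid out; check this against your own copy. `parseWineCsv` is a hypothetical helper, and its return shape simply mirrors the `{ inputs, labels }` used in the outline above.

```javascript
// Extract the density (input) and alcohol (label) columns from CSV text.
// Assumes a ';' separator and a header row whose (possibly quoted)
// column names include "density" and "alcohol"; verify against your file.
// parseWineCsv is a hypothetical helper name.
function parseWineCsv(text) {
  const lines = text.trim().split(/\r?\n/)
  const header = lines[0].split(';').map(h => h.replace(/"/g, '').trim())
  const densityCol = header.indexOf('density')
  const alcoholCol = header.indexOf('alcohol')
  if (densityCol < 0 || alcoholCol < 0) {
    throw new Error('density/alcohol columns not found')
  }
  const inputs = []
  const labels = []
  for (const line of lines.slice(1)) {
    // Empty cells become NaN instead of Number('') === 0
    const cells = line.split(';').map(c => (c.trim() === '' ? NaN : Number(c)))
    const density = cells[densityCol]
    const alcohol = cells[alcoholCol]
    // Skip malformed rows: too few cells, empty or non-numeric values
    if (Number.isFinite(density) && Number.isFinite(alcohol)) {
      inputs.push(density)
      labels.push(alcohol)
    }
  }
  return { inputs, labels }
}

// Made-up rows for illustration only (not taken from the dataset)
const sample = [
  '"density";"alcohol";"quality"',
  '0.998;9.4;5',
  '0.994;11.2;6',
].join('\n')
const { inputs, labels } = parseWineCsv(sample)
console.log(inputs, labels) // [ 0.998, 0.994 ] [ 9.4, 11.2 ]
```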

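One observation that can be made from the dataset is that alcohol content and density are negatively correlated: wines with more alcohol tend to be less dense. This makes density a natural single input for a simple linear regression model that predicts alcohol content.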
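Once the density and alcohol columns are loaded as arrays, the sign of this relationship can be checked with Pearson's correlation coefficient, which ranges from -1 (perfect negative linear relationship) to +1 (perfect positive one). A negative coefficient also implies that the least-squares line of alcohol against density has a negative slope. `pearson` is a hypothetical helper, and the values in the usage example are made up rather than taken from the dataset.

```javascript
// Pearson correlation coefficient between two equally long arrays.
// pearson is a hypothetical helper name.
function pearson(xs, ys) {
  const n = xs.length
  const meanX = xs.reduce((a, b) => a + b, 0) / n
  const meanY = ys.reduce((a, b) => a + b, 0) / n
  let sxy = 0
  let sxx = 0
  let syy = 0
  for (let i = 0; i < n; i++) {
    const dx = xs[i] - meanX
    const dy = ys[i] - meanY
    sxy += dx * dy
    sxx += dx * dx
    syy += dy * dy
  }
  return sxy / Math.sqrt(sxx * syy)
}

// Made-up values for illustration only (not from the dataset):
// alcohol falls as density rises, so the coefficient is negative.
const r = pearson([0.990, 0.993, 0.996, 0.999], [12.5, 11.6, 10.4, 9.5])
console.log(r < 0) // true
```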


Made by Anton Vasetenkov.
