Multi-Fidelity
The following example is a 1D multi-fidelity regression problem. Blue points are low-fidelity observations and orange points are high-fidelity observations. The goal is to use both datasets to learn the task on the [0, 1] interval.
To tackle this challenge, we use the parameter-based method RegularTransferNN.
[1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from sklearn.metrics import mean_absolute_error, mean_squared_error
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.optimizers import Adam, SGD, RMSprop, Adagrad
from tensorflow.keras.layers import Dense, Input, Dropout, Conv2D, MaxPooling2D, Flatten, Reshape, GaussianNoise, BatchNormalization
from tensorflow.keras.constraints import MinMaxNorm
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import clone_model
from adapt.parameter_based import RegularTransferNN
Setup
[2]:
# Build the two synthetic datasets on a shared 200-point grid over [0, 1].
def _clean_signal(x):
    """Return the noise-free curve shared by both fidelities (no offset)."""
    return (1 - x**2) * np.sin(2 * 2 * np.pi * x) - x


np.random.seed(0)  # fix the noise draw used for the low-fidelity labels
Xs = np.linspace(0, 1, 200)
# Low fidelity: the shared curve plus Gaussian noise scaled by 0.1.
ys = _clean_signal(Xs) + 0.1 * np.random.randn(Xs.size)
# High fidelity: noise-free and shifted by -1.5, on the first 100 grid points.
Xt = Xs[:100]
yt = _clean_signal(Xt) - 1.5
# Ground truth: the high-fidelity function evaluated on the whole grid.
gt = _clean_signal(Xs) - 1.5
[3]:
# Show both training sets together with the ground-truth curve.
fig, ax = plt.subplots(figsize=(10, 6))
marker_style = dict(ms=15, alpha=0.9, markeredgecolor="black")
ax.plot(Xs, ys, '.', label="low fidelity", **marker_style)
ax.plot(Xt, yt, '.', label="high fidelity", **marker_style)
ax.plot(Xs, gt, label="Ground truth", c="black", alpha=0.7, ls="--")
ax.legend(fontsize=14)
ax.set_xlabel("X", fontsize=16)
ax.set_ylabel("y = f(X)", fontsize=16)
plt.show()
Network
[4]:
# Seed both random number generators before creating the network.
tf.random.set_seed(0)
np.random.seed(0)
# Template network: one scalar input, two hidden ReLU layers of 100 units,
# one linear output. The later cells clone this model rather than fit it.
model = Sequential([
    Dense(100, activation='relu', input_shape=(1,)),
    Dense(100, activation='relu'),
    Dense(1),
])
model.compile(loss='mean_squared_error', optimizer=Adam(0.001))
Low fidelity only
[5]:
# Baseline 1: train a copy of the template network on the low-fidelity
# data only.
np.random.seed(0)
tf.random.set_seed(0)
# clone_model rebuilds the architecture; per the Keras docs the clone gets
# newly initialised weights rather than a copy of the original's.
model_low = clone_model(model)
model_low.compile(optimizer=Adam(0.001), loss='mean_squared_error')
# Pass the features as an explicit (n_samples, 1) matrix, as the predict
# calls in this notebook do, instead of relying on Keras to expand the
# 1-D array to match the network's input_shape=(1,).
model_low.fit(Xs.reshape(-1, 1), ys, epochs=800, batch_size=34, verbose=0);
[6]:
# Score the low-fidelity-only model against the ground truth on the full
# grid, then overlay its predictions on the data.
yp = model_low.predict(Xs.reshape(-1, 1))
score = mean_absolute_error(gt.ravel(), yp.ravel())

fig, ax = plt.subplots(figsize=(10, 6))
marker_style = dict(ms=15, alpha=0.9, markeredgecolor="black")
ax.plot(Xs, ys, '.', label="low fidelity", **marker_style)
ax.plot(Xt, yt, '.', label="high fidelity", **marker_style)
ax.plot(Xs, gt, label="Ground truth", c="black", alpha=0.7, ls="--")
ax.plot(Xs, yp, label="Predictions", c="red", alpha=0.9, lw=3)
ax.legend(fontsize=14)
ax.set_xlabel("X", fontsize=16)
ax.set_ylabel("y = f(X)", fontsize=16)
ax.set_title("Low Fidelity Only -- MAE = %.3f" % score, fontsize=18)
plt.show()
High fidelity only
[7]:
# Baseline 2: train a copy of the template network on the high-fidelity
# data only.
np.random.seed(0)
tf.random.set_seed(0)
# clone_model rebuilds the architecture; per the Keras docs the clone gets
# newly initialised weights rather than a copy of the original's.
model_high = clone_model(model)
model_high.compile(optimizer=Adam(0.001), loss='mean_squared_error')
# Pass the features as an explicit (n_samples, 1) matrix, as the predict
# calls in this notebook do, instead of relying on Keras to expand the
# 1-D array to match the network's input_shape=(1,).
model_high.fit(Xt.reshape(-1, 1), yt, epochs=800, batch_size=34, verbose=0);
[8]:
# Score the high-fidelity-only model against the ground truth on the full
# grid (including the part of the grid with no high-fidelity samples), then
# overlay its predictions on the data.
yp = model_high.predict(Xs.reshape(-1,1))
score = mean_absolute_error(gt.ravel(), yp.ravel())
plt.figure(figsize=(10,6))
plt.plot(Xs, ys, '.', label="low fidelity", ms=15, alpha=0.9, markeredgecolor="black")
plt.plot(Xt, yt, '.', label="high fidelity", ms=15, alpha=0.9, markeredgecolor="black")
plt.plot(Xs, gt, c="black", alpha=0.7, ls="--", label="Ground truth")
plt.plot(Xs, yp, c="red", alpha=0.9, lw=3, label="Predictions")
plt.legend(fontsize=14)
plt.xlabel("X", fontsize=16)
plt.ylabel("y = f(X)", fontsize=16)
# The title names the model actually being evaluated here (model_high).
plt.title("High Fidelity Only -- MAE = %.3f"%score, fontsize=18)
plt.show()
RegularTransferNN
[10]:
# Transfer step: wrap the network trained on low-fidelity data and fit it on
# the high-fidelity samples.
# NOTE(review): per the ADAPT docs, `lambdas` weights the penalty that keeps
# the fine-tuned weights close to the pretrained ones -- confirm against the
# installed version.
model_reg = RegularTransferNN(
    model_low,
    lambdas=1000.,
    random_state=1,
    optimizer=Adam(0.0001),
)
model_reg.fit(
    Xt.reshape(-1, 1),
    yt,
    epochs=1200,
    batch_size=34,
    verbose=0,
);
WARNING:tensorflow:Gradients do not exist for variables ['Variable:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['Variable:0'] when minimizing the loss.
[11]:
# Score the RegularTransferNN model against the ground truth on the full
# grid, then overlay its predictions on the data.
yp = model_reg.predict(Xs.reshape(-1,1))
score = mean_absolute_error(gt.ravel(), yp.ravel())
plt.figure(figsize=(10,6))
plt.plot(Xs, ys, '.', label="low fidelity", ms=15, alpha=0.9, markeredgecolor="black")
plt.plot(Xt, yt, '.', label="high fidelity", ms=15, alpha=0.9, markeredgecolor="black")
plt.plot(Xs, gt, c="black", alpha=0.7, ls="--", label="Ground truth")
plt.plot(Xs, yp, c="red", alpha=0.9, lw=3, label="Predictions")
plt.legend(fontsize=14)
plt.xlabel("X", fontsize=16)
plt.ylabel("y = f(X)", fontsize=16)
# The title names the model actually being evaluated here (model_reg).
plt.title("RegularTransferNN -- MAE = %.3f"%score, fontsize=18)
plt.show()