# non-linear second axis in matplotlib


Since the question explicitly asks for an arbitrary relation between the two axes (or refuses to clarify), here is code that plots one: the two axes are linked by a randomly chosen linear mapping and its inverse.

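```python
import matplotlib.pyplot as plt
import numpy as np

# Random (hence arbitrary) coefficients for the map time = a * T + b
a, b = (2 * np.random.rand(2) - 1) * np.random.randint(1, 500, size=2)
time = lambda T: a * T + b    # forward: temperature -> time
Temp = lambda t: (t - b) / a  # inverse: time -> temperature

T = np.linspace(0, 100, 301)
y = T ** 2

fig, ax = plt.subplots()

ax.set_xlabel("Temperature")
ax.plot(T, y)

# A float location places the secondary axis at that position in axes
# coordinates; -0.2 puts it below the primary x-axis
ax2 = ax.secondary_xaxis(-0.2, functions=(time, Temp))
ax2.set_xlabel("Time")

plt.show()
```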
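The mapping above is random but still linear. The same `functions=(forward, inverse)` mechanism also handles a genuinely non-linear relation. Both functions must be monotonic, exact inverses of each other, and able to accept NumPy arrays.

Below is a minimal sketch that uses cube and cube root as an arbitrary illustrative pair. They were chosen because both are defined for negative inputs, so axis margins below zero do not produce NaNs. The Agg backend line is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this and call plt.show() to view the plot
import matplotlib.pyplot as plt
import numpy as np


def forward(x):
    """Map primary-axis values to secondary-axis values (non-linear)."""
    return np.asarray(x, dtype=float) ** 3


def inverse(x):
    """Exact inverse of forward; np.cbrt is also defined for negative values."""
    return np.cbrt(x)


x = np.linspace(0, 10, 101)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))
ax.set_xlabel("x")

# Secondary axis along the top whose tick labels follow x**3
secax = ax.secondary_xaxis("top", functions=(forward, inverse))
secax.set_xlabel("x**3")

fig.canvas.draw()  # render once so the secondary ticks are actually computed
```

The tick positions on the top axis are spaced non-uniformly because the secondary axis uses `inverse` to place its values back onto the primary scale.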

Suggestion : 2

Based on this Stack Overflow post, a non-linear invertible function can be used to define a secondary PyPlot axis scale. As explained in this matplotlib documentation, sometimes we want to relate the axes through a transform that is derived empirically. The forward and inverse transform functions can then be set to linear interpolations between the two datasets.

NOTE: the simple example is just for illustration purposes

```julia
using PyPlot

# For a medium with a velocity gradient: V(z) = k * z + V0
const k = 1.5    # 1/s
const V0 = 1500  # m/s

# Invertible functions that define the secondary axis in the PyPlot plot
Depth2Time(depth) = @. log(k * depth + V0) / k - log(V0) / k
Time2Depth(time) = @. exp(k * (time + log(V0) / k)) / k - V0 / k

depth = 0:2:2000  # meter
yz = @. sin(2π * 10 * k * depth / V0) * exp(-100 * k^2 * (depth - 500)^2 / V0^2)

fig, ax = plt.subplots()

ax.set_xlabel("Depth [m]", color=:blue)
ax.tick_params(axis="x", colors="blue")
ax.set_ylabel("Amplitude")
ax.plot(depth, yz)

# THIS IS THE KEY PYPLOT LINE:
ax2 = ax.secondary_xaxis(-0.2, functions=(Depth2Time, Time2Depth), color=:red)
ax2.set_xlabel("Time [s]")
plt.show()
```
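The two closed-form functions follow from integrating slowness over depth for the linear velocity law given in the code comment. Here is a short derivation, with t the one-way traveltime. It is standard kinematics, written out here rather than taken from the post.

```latex
V(z) = k z + V_0, \qquad
t(z) = \int_0^{z} \frac{dz'}{k z' + V_0}
     = \frac{1}{k}\ln\!\left(\frac{k z + V_0}{V_0}\right)
     = \frac{\ln(k z + V_0)}{k} - \frac{\ln V_0}{k}

z(t) = \frac{V_0}{k}\left(e^{k t} - 1\right)
     = \frac{1}{k}\exp\!\left(k\left(t + \frac{\ln V_0}{k}\right)\right) - \frac{V_0}{k}
```

Since k > 0, t(z) is strictly increasing wherever k z + V0 > 0. That monotonicity is what makes `(Depth2Time, Time2Depth)` a valid forward/inverse pair for `secondary_xaxis`.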

However, PyPlot does not seem to consume Julia’s interpolation functions, issuing error:

```
RuntimeError: <PyCall.jlwrap (in a Julia function called from Python)
JULIA: MethodError: no method matching (::Interpolations.GriddedInterpolation{Float64, 1, Float64, Gridded{Linear}, Tuple{Vector{Float64}}})(::Matrix{Float64})
```

See MWE below:

```julia
using PyPlot, Interpolations

N = 200
velocity0 = rand(1500:3000, N - 1)                     # interval velocity [m/s]
depth0 = collect(LinRange(0, 2, N))                    # depth [km]
time0 = [0; cumsum(1000 * diff(depth0) ./ velocity0)]  # time [s]

itp_fwd = interpolate((time0,), depth0, Gridded(Linear()))
itp_inv = interpolate((depth0,), time0, Gridded(Linear()))

Time2Depth(time) = itp_fwd(time)    # s to km
Depth2Time(depth) = itp_inv(depth)  # km to s

time = collect(0:0.002:0.8)
Depth2Time(Time2Depth(time)) ≈ time  # checks that the inverse transforms work fine

yt = @. sin(2π * 30 * time) * exp(-900 * (time - 0.3)^2)

fig, ax = plt.subplots()

ax.set_xlabel("Time [s]", color=:blue)
ax.tick_params(axis="x", colors="blue")
ax.set_ylabel("Amplitude")
ax.plot(time, yt)
plt.xlim(extrema(time))
# The following key PyPlot command fails with functions from Interpolations.jl:
ax2 = ax.secondary_xaxis(-0.2, functions=(Time2Depth, Depth2Time), color=:red)

ax2.set_xlabel("Depth [km]")
plt.show()
```

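The error message shows that the interpolation object is being called with a whole `Matrix{Float64}` coming from Python, and `GriddedInterpolation` has no method for that argument type. You just need to broadcast the call element-wise:

```julia
Time2Depth(time) = itp_fwd.(time)    # s to km
Depth2Time(depth) = itp_inv.(depth)  # km to s
```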
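Python has the same pitfall. matplotlib evaluates the transform functions on NumPy arrays (the Julia traceback above shows a `Matrix{Float64}` arriving), so a function written only for scalars has to be vectorized.

Below is a minimal sketch built on a scalar-only Python port of the `Depth2Time` formula from the first Julia example. The names `depth_to_time_scalar`, `K` and `V0` are hypothetical. `np.vectorize` plays the role of Julia's `f.(x)` broadcast here.

```python
import math

import numpy as np

K = 1.5      # velocity gradient [1/s], as in the Julia example
V0 = 1500.0  # velocity at zero depth [m/s]


def depth_to_time_scalar(depth_m):
    """Hypothetical scalar-only port of Depth2Time; math.log accepts only scalars."""
    return math.log(K * depth_m + V0) / K - math.log(V0) / K


# Element-wise wrapper, analogous to calling Depth2Time.(depth) in Julia
depth_to_time = np.vectorize(depth_to_time_scalar)

# An (N, 1) array, the kind of 2-D input a transform function may receive
depths = np.array([[0.0], [500.0], [2000.0]])

try:
    depth_to_time_scalar(depths)
except TypeError as err:
    print("scalar-only version fails on arrays:", err)

times = depth_to_time(depths)
print(times.shape)  # (3, 1)
```

`np.vectorize` is a convenience loop in Python, not a speed-up. Writing the function with NumPy ufuncs (`np.log` instead of `math.log`) is usually the cleaner fix, just as `@.` made the closed-form Julia functions work.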

Suggestion : 3

Here is the case of converting from wavenumber to wavelength in a log-log scale. The code below labels the axes as frequency and period, but the 1/x relation is the same. In this case, the xscale of the parent is logarithmic, so the child is made logarithmic as well.

In order to properly handle the data margins, the mapping functions (`forward` and `inverse` in the interpolation example) need to be defined beyond the nominal plot limits. In the specific case of the NumPy linear interpolation, `numpy.interp`, this condition can be arbitrarily enforced by providing the optional keyword arguments `left` and `right`. Values outside the data range are then mapped well outside the plot limits.

```python
import datetime

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import AutoMinorLocator

fig, ax = plt.subplots(constrained_layout=True)
x = np.arange(0, 360, 1)
y = np.sin(2 * x * np.pi / 180)
ax.plot(x, y)
ax.set_xlabel('angle [degrees]')
ax.set_ylabel('signal')
ax.set_title('Sine wave')

def deg2rad(x):
    return x * np.pi / 180

def rad2deg(x):
    return x * 180 / np.pi

secax = ax.secondary_xaxis('top', functions=(deg2rad, rad2deg))
secax.set_xlabel('angle [rad]')
plt.show()
```
```python
fig, ax = plt.subplots(constrained_layout=True)
x = np.arange(0.02, 1, 0.02)
np.random.seed(19680801)
y = np.random.randn(len(x)) ** 2
ax.loglog(x, y)
ax.set_xlabel('f [Hz]')
ax.set_ylabel('PSD')
ax.set_title('Random spectrum')

def one_over(x):
    """Vectorized 1/x, treating x==0 manually"""
    x = np.array(x).astype(float)
    near_zero = np.isclose(x, 0)
    x[near_zero] = np.inf
    x[~near_zero] = 1 / x[~near_zero]
    return x

# the function "1/x" is its own inverse
inverse = one_over

secax = ax.secondary_xaxis('top', functions=(one_over, inverse))
secax.set_xlabel('period [s]')
plt.show()
```
```python
fig, ax = plt.subplots(constrained_layout=True)
xdata = np.arange(1, 11, 0.4)
ydata = np.random.randn(len(xdata))
ax.plot(xdata, ydata, label='Plotted data')

xold = np.arange(0, 11, 0.2)
# fake data set relating x coordinate to another data-derived coordinate.
# xnew must be monotonic, so we sort...
xnew = np.sort(10 * np.exp(-xold / 4) + np.random.randn(len(xold)) / 3)

ax.plot(xold[3:], xnew[3:], label='Transform data')
ax.set_xlabel('X [m]')
ax.legend()

def forward(x):
    return np.interp(x, xold, xnew)

def inverse(x):
    return np.interp(x, xnew, xold)

secax = ax.secondary_xaxis('top', functions=(forward, inverse))
secax.xaxis.set_minor_locator(AutoMinorLocator())
secax.set_xlabel('$X_{other}$')

plt.show()
```
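The margin note above can be seen directly in how `np.interp` behaves outside its table. By default it clamps to the end values, so an interpolated forward/inverse pair stops being invertible there. Below is a minimal sketch with a small made-up monotonic table; `x_table` and `y_table` are hypothetical values, not from the source.

```python
import numpy as np

# Hypothetical monotonic lookup table (both columns strictly increasing)
x_table = np.array([0.0, 1.0, 2.0, 4.0, 8.0])
y_table = np.array([0.0, 0.5, 2.0, 3.0, 10.0])


def forward(x):
    """Map x to y by linear interpolation in the table."""
    return np.interp(x, x_table, y_table)


def inverse(y):
    """Map y back to x; np.interp needs y_table increasing for this direction."""
    return np.interp(y, y_table, x_table)


inside = np.array([0.5, 3.0, 7.5])
outside = np.array([-1.0, 9.0])  # e.g. where axis margins extend past the table

print(np.allclose(inverse(forward(inside)), inside))    # True
print(forward(outside))                                 # clamped to the end values 0 and 10
print(np.allclose(inverse(forward(outside)), outside))  # False
```

Inside the table, the round trip recovers the input. Outside it, both out-of-range points collapse onto the table ends. This is why the mapping functions need to cover the full axis range, margins included.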
```python
dates = [datetime.datetime(2018, 1, 1) + datetime.timedelta(hours=k * 6)
         for k in range(240)]
temperature = np.random.randn(len(dates)) * 4 + 6.7
fig, ax = plt.subplots(constrained_layout=True)

ax.plot(dates, temperature)
ax.set_ylabel(r'$T\ [^oC]$')
plt.xticks(rotation=70)

def date2yday(x):
    """Convert matplotlib datenum to days since 2018-01-01."""
    y = x - mdates.date2num(datetime.datetime(2018, 1, 1))
    return y

def yday2date(x):
    """Return a matplotlib datenum for *x* days after 2018-01-01."""
    y = x + mdates.date2num(datetime.datetime(2018, 1, 1))
    return y

secax_x = ax.secondary_xaxis('top', functions=(date2yday, yday2date))
secax_x.set_xlabel('yday [2018]')

def celsius_to_fahrenheit(x):
    return x * 1.8 + 32

def fahrenheit_to_celsius(x):
    return (x - 32) / 1.8

secax_y = ax.secondary_yaxis(
    'right', functions=(celsius_to_fahrenheit, fahrenheit_to_celsius))
secax_y.set_ylabel(r'$T\ [^oF]$')

def celsius_to_anomaly(x):
    return (x - np.mean(temperature))

def anomaly_to_celsius(x):
    return (x + np.mean(temperature))

# use of a float for the position:
secax_y2 = ax.secondary_yaxis(
    1.2, functions=(celsius_to_anomaly, anomaly_to_celsius))
secax_y2.set_ylabel(r'$T - \overline{T}\ [^oC]$')

plt.show()
```