Customized Functions#
In this tutorial, we will learn how to create customized functions that we can reuse in our projects to return default traces.
We’ll focus on five key functions: create_line_trace
, create_point_trace
, create_arrow_trace
, create_3d_layout
and create_particle_animation
. Each
of these functions helps us to build different elements of a 3D plot, which we can then combine into a
comprehensive visualization.
Note
We are writing these functions to simplify our work. Their arguments come with sensible default values, so we do not need to repeat them on every call; the defaults make our lives easier.
Setting Up#
First, import the necessary libraries:
[28]:
import plotly.graph_objects as go
import numpy as np
Writing Functions#
create_line_trace
#
This function creates a line trace between two points in 3D space. We only need to provide the start and end points; every other argument has a sensible default.
[29]:
def create_line_trace(start, end, color='blue', width=2, name='', dash='solid', showlegend=False):
    """Return a Scatter3d trace drawing a straight segment from ``start`` to ``end``.

    Args:
        start: Indexable (x, y, z) coordinates of the first endpoint.
        end: Indexable (x, y, z) coordinates of the second endpoint.
        color: Line color. Defaults to 'blue'.
        width: Line width. Defaults to 2.
        name: Trace name. Defaults to ''.
        dash: Dash style of the line (e.g. 'solid', 'dash'). Defaults to 'solid'.
        showlegend: Whether the trace is listed in the legend. Defaults to False.

    Returns:
        go.Scatter3d: A two-point trace drawn in 'lines' mode.
    """
    # One [start, end] pair per axis, in x, y, z order.
    xs, ys, zs = ([start[k], end[k]] for k in range(3))
    style = dict(color=color, width=width, dash=dash)
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='lines',
        line=style,
        name=name,
        showlegend=showlegend,
    )
create_point_trace
#
This function creates a trace for a single point in 3D space.
[30]:
def create_point_trace(point, color='red', size=5, name=''):
    """Return a Scatter3d trace showing a single labelled marker.

    Args:
        point: Iterable of exactly three coordinates (x, y, z).
        color: Marker color. Defaults to 'red'.
        size: Marker size. Defaults to 5.
        name: Text drawn above the marker. It is used as the label text,
            not as the trace name.

    Returns:
        go.Scatter3d: A one-marker trace in 'markers+text' mode, hidden from the legend.
    """
    px, py, pz = point
    marker_style = dict(color=color, size=size)
    return go.Scatter3d(
        x=[px],
        y=[py],
        z=[pz],
        mode='markers+text',
        marker=marker_style,
        text=[name],
        textposition="top center",
        showlegend=False,
    )
create_arrow_trace
#
This function creates an arrow trace to indicate direction or movement.
[31]:
def create_arrow_trace(start, end, color='blue', name='', showlegend=False):
    """Return the two traces that draw a 3D arrow from ``start`` to ``end``.

    The shaft is a Scatter3d line. The head is a Cone anchored by its tip at ``end``.

    Args:
        start: Indexable (x, y, z) coordinates of the arrow's tail.
        end: Indexable (x, y, z) coordinates of the arrow's tip.
        color: Color of both shaft and head. Defaults to 'blue'.
        name: Name of the shaft trace. Defaults to ''.
        showlegend: Whether the shaft trace appears in the legend. Defaults to False.

    Returns:
        list: ``[shaft_trace, head_trace]``.
    """
    shaft = go.Scatter3d(
        x=[start[0], end[0]],
        y=[start[1], end[1]],
        z=[start[2], end[2]],
        mode='lines',
        line=dict(color=color, width=5),
        name=name,
        showlegend=showlegend,
    )

    direction = np.array(end) - np.array(start)
    magnitude = np.linalg.norm(direction)
    # Normalize to a unit vector; a zero-length arrow keeps its zero vector
    # so there is no division by zero.
    if magnitude > 0:
        unit = direction / magnitude
    else:
        unit = direction

    # The cone's sizeref is set to 10% of the arrow's length.
    head = go.Cone(
        x=[end[0]], y=[end[1]], z=[end[2]],
        u=[unit[0]], v=[unit[1]], w=[unit[2]],
        sizemode="absolute",
        sizeref=0.1 * magnitude,
        showscale=False,
        anchor="tip",
        colorscale=[[0, color], [1, color]],
    )
    return [shaft, head]
create_3d_layout
#
This function sets up the layout for a 3D plot.
[32]:
def create_3d_layout(title='3D Plot', gridcolor='lightblue', xaxis_title='X Axis', yaxis_title='Y Axis', zaxis_title='Z Axis'):
    """Return a go.Layout for a 3D scene with uniformly styled axes.

    Every axis hides its background and shows its grid, a zero line
    (width 2, semi-transparent black) and tick labels. The scene uses
    aspectmode 'cube'.

    Args:
        title: Figure title. Defaults to '3D Plot'.
        gridcolor: Grid line color for all three axes. Defaults to 'lightblue'.
        xaxis_title: Title of the x axis. Defaults to 'X Axis'.
        yaxis_title: Title of the y axis. Defaults to 'Y Axis'.
        zaxis_title: Title of the z axis. Defaults to 'Z Axis'.

    Returns:
        go.Layout: The configured layout.
    """
    def axis(label):
        # All three axes share the same styling; only their titles differ.
        return dict(
            showbackground=False,
            showgrid=True,
            zeroline=True,
            zerolinewidth=2,
            zerolinecolor='rgba(0,0,0,0.5)',
            showticklabels=True,
            gridcolor=gridcolor,
            title=label,
        )

    scene = dict(
        xaxis=axis(xaxis_title),
        yaxis=axis(yaxis_title),
        zaxis=axis(zaxis_title),
    )
    return go.Layout(title=title, scene=scene, scene_aspectmode='cube')
animate_particle
#
Creates animation frames for a particle moving along a curve in a 3D Plotly figure.
[33]:
def animate_particle(curve_points, particle_name='Particle', particle_color='red', particle_size=6, animation_speed=10):
# Creating frames for animation
frames = []
for i in range(0, len(curve_points), animation_speed):
frame = go.Frame(data=[go.Scatter3d(
x=[curve_points[i][0]],
y=[curve_points[i][1]],
z=[curve_points[i][2]],
mode='markers+text',
marker=dict(color=particle_color, size=particle_size),
text=[particle_name],
textposition='top center',
textfont=dict(size=15)
)])
frames.append(frame)
return frames
create_particle_animation
#
Creates a 3D animation of a particle moving along a specified path.
[34]:
def create_particle_animation(curve_points, title='Particle Animation', name='P', origin=None):
    """Build a 3D Plotly figure that animates a particle along a path.

    Args:
        curve_points (np.ndarray): Non-empty array of shape (N, 3); each row is
            an (x, y, z) position along the path.
        title (str): Figure title. Defaults to 'Particle Animation'.
        name (str): Label shown next to the particle. Defaults to 'P'.
        origin: (x, y, z) coordinates of the reference point labelled 'O'.
            Defaults to [0, 0, 0].

    Returns:
        go.Figure: A figure containing:
            - the particle at its start position (green);
            - the origin marker;
            - the blue path line;
            - one animation frame per curve point;
            - Play/Pause buttons.

    Raises:
        ValueError: If ``curve_points`` is not a NumPy array, or is not a
            non-empty 2D array with exactly 3 columns.
    """
    # Validate curve_points
    if not isinstance(curve_points, np.ndarray):
        raise ValueError("curve_points must be a NumPy array.")
    if curve_points.size == 0 or curve_points.ndim != 2 or curve_points.shape[1] != 3:
        raise ValueError("curve_points must be a non-empty 2D NumPy array with 3 columns (x, y, z coordinates).")

    # Use a None sentinel instead of a list default, which would be a single
    # object shared by every call.
    if origin is None:
        origin = [0, 0, 0]

    x_vals, y_vals, z_vals = zip(*curve_points)

    # Trace order matters: the frames built below contain a single trace and
    # do not name target traces, so Plotly applies them starting at trace 0.
    # The moving particle must therefore be the first trace.
    traces = [
        create_point_trace(curve_points[0], color='green', size=8, name=name),
        create_point_trace(origin, color='black', size=3, name='O'),
        go.Scatter3d(
            x=x_vals, y=y_vals, z=z_vals,
            mode="lines",
            line=dict(color="blue", width=2),
            name='Path',
        ),
    ]

    layout = create_3d_layout(title=title, xaxis_title='X Axis', yaxis_title='Y Axis', zaxis_title='Z Axis')

    # animation_speed=1 turns every curve point into a frame.
    frames = animate_particle(curve_points, name, particle_color='green', particle_size=8, animation_speed=1)

    fig = go.Figure(data=traces, layout=layout, frames=frames)

    fig.update_layout(
        scene=dict(
            camera=dict(
                up=dict(x=0, y=0, z=1),          # the z-axis points up
                center=dict(x=0, y=0, z=0),      # centers the view on the given coordinates
                eye=dict(x=1, y=-1.25, z=1.25),  # position of the camera
            ),
            aspectmode='cube',  # keeps the aspect ratio of the axes fixed
        ),
        updatemenus=[
            dict(
                type="buttons",
                buttons=[
                    dict(label="Play",
                         method="animate",
                         args=[None, dict(frame=dict(duration=50, redraw=True), fromcurrent=True)]),
                    # Animating to [None] with zero duration halts playback.
                    dict(label="Pause",
                         method="animate",
                         args=[[None], dict(frame=dict(duration=0, redraw=False), mode="immediate")]),
                ],
            )
        ],
    )
    return fig
Docs and Errors#
It is always good practice to add error handling and clear documentation to your functions. For simplicity, I do not include them here, but you can add them quickly with the help of AI tools such as ChatGPT or Bard.
Here is an example of a function with error handling and a docstring.
[35]:
def create_arrow_trace(start, end, color='blue', name='', showlegend=False):
    """
    Creates traces representing an arrow for a Plotly 3D plot.

    The arrow is drawn as a Scatter3d line (the shaft) plus a Cone whose tip
    sits at 'end' (the head). Coordinates are converted to floats before use.

    Parameters:
    - start (list, tuple or array): The starting point of the arrow, specified as [x, y, z].
    - end (list, tuple or array): The ending point (tip) of the arrow, specified as [x, y, z].
    - color (str, optional): Color of the arrow. Defaults to 'blue'.
    - name (str, optional): Name of the arrow, used for legend and hover. Defaults to an empty string.
    - showlegend (bool, optional): Whether to show the legend entry for this arrow. Defaults to False.

    Returns:
    - list: A list containing the Scatter3d trace for the arrow shaft and the Cone trace for the arrowhead.

    Raises:
    - ValueError: If 'start' or 'end' are not lists, tuples or arrays with exactly three elements,
      or if they contain values that cannot be converted to float.
    - ValueError: If 'color' is not a string.
    - ValueError: If 'name' is not a string.
    - ValueError: If 'showlegend' is not a boolean.
    """
    # Validate start and end points and convert their coordinates to floats.
    converted = []
    for point in (start, end):
        if not isinstance(point, (list, tuple, np.ndarray)) or len(point) != 3:
            raise ValueError("The 'start' and 'end' parameters must be lists, tuples or arrays with exactly three elements (x, y, z).")
        try:
            # Build the list eagerly: a bare map(float, point) is lazy, so the
            # conversion (and therefore this check) would never actually run.
            converted.append([float(value) for value in point])
        except (TypeError, ValueError) as err:
            # float() raises TypeError for e.g. None and ValueError for e.g. 'abc'.
            raise ValueError("The 'start' and 'end' points must contain numerical values.") from err
    start_xyz, end_xyz = converted

    # Validate other parameters
    if not isinstance(color, str):
        raise ValueError("The 'color' parameter must be a string.")
    if not isinstance(name, str):
        raise ValueError("The 'name' parameter must be a string.")
    if not isinstance(showlegend, bool):
        raise ValueError("The 'showlegend' parameter must be a boolean.")

    # Create arrow shaft trace
    shaft_trace = go.Scatter3d(
        x=[start_xyz[0], end_xyz[0]], y=[start_xyz[1], end_xyz[1]], z=[start_xyz[2], end_xyz[2]],
        mode='lines',
        line=dict(color=color, width=5),
        name=name,
        showlegend=showlegend,
    )

    # Create arrowhead trace; a zero-length arrow keeps its zero vector so
    # there is no division by zero.
    vec = np.array(end_xyz) - np.array(start_xyz)
    length = np.linalg.norm(vec)
    vec_normalized = vec / length if length > 0 else vec
    # Size of the arrowhead relative to the arrow length
    arrowhead_size = 0.1 * length
    head_trace = go.Cone(
        x=[end_xyz[0]], y=[end_xyz[1]], z=[end_xyz[2]],
        u=[vec_normalized[0]], v=[vec_normalized[1]], w=[vec_normalized[2]],
        sizemode="absolute", sizeref=arrowhead_size, showscale=False,
        anchor="tip", colorscale=[[0, color], [1, color]],
    )
    return [shaft_trace, head_trace]
Combining the Elements#
Now, let’s use these functions to create a simple 3D plot.
Static Plot#
[36]:
# Endpoints shared by the arrow, the marker and the dashed guide line
start_point = [0, 0, 0]
end_point = [1, 0, 4]

# Layout: a titled 3D scene using the default axis styling
layout = create_3d_layout(title='My 3D Visualization')

# Traces: a dashed red guide line from the arrow tip down to (1, 0, 0),
# a green marker at the tip, and the arrow itself (shaft + cone head)
line_trace = create_line_trace(start=end_point, end=[1, 0, 0], color='red', dash='dash')
point_trace = create_point_trace(point=end_point, color='green')
arrow_trace = create_arrow_trace(start=start_point, end=end_point)

# arrow_trace is a list of two traces, so unpack it into the data list
fig = go.Figure(data=[line_trace, point_trace, *arrow_trace], layout=layout)
fig.show()
Animation#
[37]:
# Example parameters
initial_position = np.array([0, 0, -1])
velocity = np.array([1, 0, 0])
time_values = np.linspace(0, 2, 11)
# Evaluate r(t) = t * (t * initial_position - velocity) for every time step at
# once: broadcasting the (N, 1) column of times against the length-3 vectors
# yields an (N, 3) array of points. Vectorizing also avoids leaving a global
# loop variable named `time`, which would shadow the standard-library module.
time_column = time_values[:, np.newaxis]
curve_points = time_column * (time_column * initial_position - velocity)
# Create a particle animation using the generated curve points
animation_figure = create_particle_animation(curve_points)
# Show the animation
animation_figure.show()
Feel free to expand upon this tutorial by adding more complex elements, adjusting parameters, or incorporating these plots into your data analysis workflow. Happy plotting!