Customized Functions#

In this tutorial, we will learn how to create customised, reusable functions that return Plotly traces with sensible default settings, so we can reuse them throughout our project work.

We’ll focus on the following functions: create_line_trace, create_point_trace, create_arrow_trace, create_3d_layout, animate_particle and create_particle_animation. Each of these functions helps us build a different element of a 3D plot, and we can then combine those elements into a comprehensive visualization.

Note

We are writing these functions to simplify our work. Their arguments have default values, so we don't need to repeat them on every call. The defaults make our lives easier.

Setting Up#

First, import the necessary libraries

[28]:
import plotly.graph_objects as go
import numpy as np

Writing Functions#

create_line_trace#

This function creates a line trace between two points in 3D space. We only need to pass the start and end points; every other argument has a default, which keeps quick work quick.

[29]:
def create_line_trace(start, end, color='blue', width=2, name='', dash='solid', showlegend=False):
    """
    Return a Scatter3d trace drawing a straight segment from ``start`` to ``end``.

    Parameters:
    - start, end (indexable): Endpoints; elements 0, 1 and 2 are used as x, y and z.
    - color (str, optional): Line colour. Defaults to 'blue'.
    - width (number, optional): Line width. Defaults to 2.
    - name (str, optional): Trace name (legend / hover). Defaults to ''.
    - dash (str, optional): Plotly dash style, e.g. 'solid' or 'dash'. Defaults to 'solid'.
    - showlegend (bool, optional): Whether the trace appears in the legend. Defaults to False.

    Returns:
    - plotly.graph_objects.Scatter3d
    """
    # One two-element coordinate list per axis: [start_k, end_k] for k = x, y, z.
    xs, ys, zs = ([start[axis], end[axis]] for axis in range(3))
    line_style = dict(color=color, width=width, dash=dash)
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='lines',
        line=line_style,
        name=name,
        showlegend=showlegend,
    )

create_point_trace#

This function creates a trace for a single point in 3D space.

[30]:
def create_point_trace(point, color='red', size=5, name=''):
    """
    Return a Scatter3d trace showing a single labelled marker at ``point``.

    Parameters:
    - point (sequence of exactly 3 numbers): The (x, y, z) position.
    - color (str, optional): Marker colour. Defaults to 'red'.
    - size (number, optional): Marker size. Defaults to 5.
    - name (str, optional): Text drawn above the marker. Defaults to ''.

    Returns:
    - plotly.graph_objects.Scatter3d (never shown in the legend).
    """
    px, py, pz = point
    marker_style = dict(color=color, size=size)
    return go.Scatter3d(
        x=[px],
        y=[py],
        z=[pz],
        mode='markers+text',
        marker=marker_style,
        # ``name`` is used as an on-plot label rather than a legend entry.
        text=[name],
        textposition="top center",
        showlegend=False,
    )

create_arrow_trace#

This function creates an arrow trace to indicate direction or movement.

[31]:
def create_arrow_trace(start, end, color='blue', name='', showlegend=False):
    """
    Return the two traces that together draw an arrow from ``start`` to ``end``.

    Parameters:
    - start, end (sequence or array): Tail and tip; elements 0, 1, 2 are x, y, z.
    - color (str, optional): Colour of shaft and head. Defaults to 'blue'.
    - name (str, optional): Shaft trace name. Defaults to ''.
    - showlegend (bool, optional): Whether the shaft appears in the legend. Defaults to False.

    Returns:
    - list: [Scatter3d shaft, Cone head].
    """
    shaft = go.Scatter3d(
        x=[start[0], end[0]],
        y=[start[1], end[1]],
        z=[start[2], end[2]],
        mode='lines',
        line=dict(color=color, width=5),
        name=name,
        showlegend=showlegend,
    )

    direction = np.array(end) - np.array(start)
    magnitude = np.linalg.norm(direction)
    if magnitude > 0:
        unit = direction / magnitude
    else:
        # Zero-length arrow: leave the (all-zero) vector as is to avoid 0/0.
        unit = direction
    # The head is scaled to a tenth of the arrow's length.
    head_scale = 0.1 * magnitude

    head = go.Cone(
        x=[end[0]], y=[end[1]], z=[end[2]],
        u=[unit[0]], v=[unit[1]], w=[unit[2]],
        sizemode="absolute", sizeref=head_scale, showscale=False,
        # anchor="tip" puts the cone's point exactly on ``end``.
        anchor="tip", colorscale=[[0, color], [1, color]],
    )

    return [shaft, head]

create_3d_layout#

This function sets up the layout for a 3D plot.

[32]:
def create_3d_layout(title='3D Plot', gridcolor='lightblue', xaxis_title='X Axis', yaxis_title='Y Axis', zaxis_title='Z Axis'):
    """
    Return a go.Layout for a 3D scene with identical styling on all three axes.

    Parameters:
    - title (str, optional): Figure title. Defaults to '3D Plot'.
    - gridcolor (str, optional): Grid line colour for every axis. Defaults to 'lightblue'.
    - xaxis_title, yaxis_title, zaxis_title (str, optional): Axis titles.

    Returns:
    - plotly.graph_objects.Layout with a cube aspect ratio.
    """
    def _axis(axis_title):
        # Shared axis styling; only the title differs between x, y and z.
        return dict(
            showbackground=False,
            showgrid=True,
            zeroline=True,
            zerolinewidth=2,
            zerolinecolor='rgba(0,0,0,0.5)',
            showticklabels=True,
            gridcolor=gridcolor,
            title=axis_title,
        )

    scene = dict(
        xaxis=_axis(xaxis_title),
        yaxis=_axis(yaxis_title),
        zaxis=_axis(zaxis_title),
    )
    return go.Layout(title=title, scene=scene, scene_aspectmode='cube')

animate_particle#

Creates animation frames for a particle moving along a curve in a 3D Plotly figure.

[33]:
def animate_particle(curve_points, particle_name='Particle', particle_color='red', particle_size=6, animation_speed=10):
    """
    Build Plotly animation frames that move a single labelled marker along a curve.

    Every ``animation_speed``-th point of ``curve_points`` becomes one frame. The
    last point is always included, so the particle finishes at the end of the
    curve even when the point count is not a multiple of the stride.

    Parameters:
    - curve_points (sequence of (x, y, z) points): The path, indexed as curve_points[i][0..2].
    - particle_name (str, optional): Label drawn above the marker. Defaults to 'Particle'.
    - particle_color (str, optional): Marker colour. Defaults to 'red'.
    - particle_size (number, optional): Marker size. Defaults to 6.
    - animation_speed (int, optional): Stride between sampled points (>= 1).
      Larger values give fewer frames and a faster-looking motion. Defaults to 10.

    Returns:
    - list of plotly.graph_objects.Frame (empty if ``curve_points`` is empty).

    Raises:
    - ValueError: If ``animation_speed`` is less than 1.
    """
    if animation_speed < 1:
        raise ValueError("animation_speed must be a positive integer (the stride between sampled points).")

    n_points = len(curve_points)
    indices = list(range(0, n_points, animation_speed))
    # The stride may step past the final point; append it so the motion ends on the curve's end.
    if indices and indices[-1] != n_points - 1:
        indices.append(n_points - 1)

    return [
        go.Frame(data=[go.Scatter3d(
            x=[curve_points[i][0]],
            y=[curve_points[i][1]],
            z=[curve_points[i][2]],
            mode='markers+text',
            marker=dict(color=particle_color, size=particle_size),
            text=[particle_name],
            textposition='top center',
            textfont=dict(size=15)
        )])
        for i in indices
    ]

create_particle_animation#

Creates a 3D animation of a particle moving along a specified path.

[34]:
def create_particle_animation(curve_points, title='Particle Animation', name='P', origin=(0, 0, 0)):
    """
    Build an animated 3D figure of a particle travelling along ``curve_points``.

    The figure contains:
    - a green marker labelled ``name`` at the first curve point (this is the trace the animation moves),
    - a small black marker labelled 'O' at ``origin``,
    - the full path drawn as a blue line,
    - one animation frame per curve point,
    - Play / Pause buttons.

    Parameters:
    - curve_points (numpy.ndarray): Non-empty array of shape (N, 3) holding the (x, y, z) path.
    - title (str, optional): Figure title. Defaults to 'Particle Animation'.
    - name (str, optional): Label shown next to the particle. Defaults to 'P'.
    - origin (sequence of 3 numbers, optional): Where the 'O' marker is drawn. Defaults to (0, 0, 0).

    Returns:
    - plotly.graph_objects.Figure

    Raises:
    - ValueError: If ``curve_points`` is not a NumPy array, or is not a non-empty (N, 3) array.
    """
    # Validate curve_points (short-circuiting keeps shape[1] from being read on 1-D input).
    if not isinstance(curve_points, np.ndarray):
        raise ValueError("curve_points must be a NumPy array.")
    if curve_points.size == 0 or curve_points.ndim != 2 or curve_points.shape[1] != 3:
        raise ValueError("curve_points must be a non-empty 2D NumPy array with 3 columns (x, y, z coordinates).")

    # Order matters: the frames built below carry a single trace and no explicit
    # trace index, so Plotly applies them to trace 0. The moving particle must stay first.
    traces = [
        create_point_trace(curve_points[0], color='green', size=8, name=name),
        create_point_trace(origin, color='black', size=3, name='O'),
    ]

    # Full path as a static line.
    x_vals, y_vals, z_vals = curve_points.T
    traces.append(go.Scatter3d(
        x=x_vals, y=y_vals, z=z_vals,
        mode="lines",
        line=dict(color="blue", width=2),
        name='Path'
    ))

    layout = create_3d_layout(title=title, xaxis_title='X Axis', yaxis_title='Y Axis', zaxis_title='Z Axis')

    # A stride of 1 gives one frame per curve point.
    frames = animate_particle(curve_points, name, particle_color='green', particle_size=8, animation_speed=1)

    fig = go.Figure(data=traces, layout=layout, frames=frames)

    fig.update_layout(
        scene=dict(
            camera=dict(
                up=dict(x=0, y=0, z=1),  # z-axis points up
                center=dict(x=0, y=0, z=0),  # look at the scene centre
                eye=dict(x=1, y=-1.25, z=1.25)  # camera position
            ),
            aspectmode='cube'  # keep the axes' aspect ratio fixed
        ),
        # Play animates from the current frame; Pause stops immediately.
        updatemenus=[
            dict(
                type="buttons",
                buttons=[
                    dict(label="Play",
                         method="animate",
                         args=[None, dict(frame=dict(duration=50, redraw=True), fromcurrent=True)]),
                    dict(label="Pause",
                         method="animate",
                         args=[[None], dict(frame=dict(duration=0, redraw=False), mode="immediate")])
                ]
            )
        ]
    )

    return fig

Docs and Errors#

It is always good practice to give your functions proper error handling and thorough documentation. For simplicity I am not going to do that here, but you can add it quickly with the help of AI tools such as ChatGPT or Bard.

Here is an example of a function with error handling and a docstring.

[35]:
def create_arrow_trace(start, end, color='blue', name='', showlegend=False):
    """
    Creates traces representing an arrow for a Plotly 3D plot.

    Parameters:
    - start (list or numpy.ndarray): The starting point of the arrow, specified as [x, y, z].
    - end (list or numpy.ndarray): The ending point (tip) of the arrow, specified as [x, y, z].
    - color (str, optional): Color of the arrow. Defaults to 'blue'.
    - name (str, optional): Name of the arrow shaft, used for legend and hover. Defaults to an empty string.
    - showlegend (bool, optional): Whether to show the legend entry for this arrow. Defaults to False.

    Every coordinate must be convertible with ``float()``. The converted float
    values are the ones used to draw the arrow.

    Returns:
    - list: A list containing the Scatter3d trace for the arrow shaft and the Cone trace for the arrowhead.

    Raises:
    - ValueError: If 'start' or 'end' are not lists or arrays with exactly three elements, or if they contain non-numerical values.
    - ValueError: If 'color' is not a string.
    - ValueError: If 'name' is not a string.
    - ValueError: If 'showlegend' is not a boolean.
    """

    # Validate start and end points, converting each coordinate to float.
    coords = []
    for point in (start, end):
        if not isinstance(point, (list, np.ndarray)) or len(point) != 3:
            raise ValueError("The 'start' and 'end' parameters must be lists or arrays with exactly three elements (x, y, z).")
        try:
            # Materialise the conversion: map() is lazy, so a bare
            # ``map(float, point)`` would never run and never detect bad values.
            coords.append([float(c) for c in point])
        except (TypeError, ValueError):
            # float('a') raises ValueError; float(None) or float([1, 2]) raise TypeError.
            raise ValueError("The 'start' and 'end' points must contain numerical values.") from None
    start_xyz, end_xyz = coords

    # Validate other parameters
    if not isinstance(color, str):
        raise ValueError("The 'color' parameter must be a string.")
    if not isinstance(name, str):
        raise ValueError("The 'name' parameter must be a string.")
    if not isinstance(showlegend, bool):
        raise ValueError("The 'showlegend' parameter must be a boolean.")

    # Create arrow shaft trace
    shaft_trace = go.Scatter3d(
        x=[start_xyz[0], end_xyz[0]], y=[start_xyz[1], end_xyz[1]], z=[start_xyz[2], end_xyz[2]],
        mode='lines',
        line=dict(color=color, width=5),
        name=name,
        showlegend=showlegend
    )

    # Create arrowhead trace: a unit direction vector, leaving a zero-length arrow's zero vector as is.
    vec = np.array(end_xyz) - np.array(start_xyz)
    length = np.linalg.norm(vec)
    vec_normalized = vec / length if length > 0 else vec

    # Size of the arrowhead relative to the arrow length
    arrowhead_size = 0.1 * length

    head_trace = go.Cone(
        x=[end_xyz[0]], y=[end_xyz[1]], z=[end_xyz[2]],
        u=[vec_normalized[0]], v=[vec_normalized[1]], w=[vec_normalized[2]],
        sizemode="absolute", sizeref=arrowhead_size, showscale=False,
        # anchor="tip" places the cone's point exactly at the arrow's end.
        anchor="tip", colorscale=[[0, color], [1, color]]
    )

    return [shaft_trace, head_trace]

Combining the Elements#

Now, let’s use these functions to create a simple 3D plot.

Static Plot#

[36]:
# Endpoints shared by the arrow (tail -> tip) and the marker placed at the tip
start_point = [0, 0, 0]
end_point = [1, 0, 4]

# Three elements: a dashed red guide from the tip to (1, 0, 0), a green
# marker at the tip, and the default blue arrow from the origin to the tip.
line_trace = create_line_trace(end_point, [1, 0, 0], dash='dash', color='red')
point_trace = create_point_trace(end_point, color='green')
arrow_trace = create_arrow_trace(start_point, end_point)  # [shaft, head]

layout = create_3d_layout(title='My 3D Visualization')

# arrow_trace is itself a list of two traces, so unpack it into the data list
fig = go.Figure(data=[line_trace, point_trace, *arrow_trace], layout=layout)
fig.show()

Animation#

[37]:
# Example parameters
initial_position = np.array([0, 0, -1])
velocity = np.array([1, 0, 0])
time_values = np.linspace(0, 2, 11)

# Evaluate r(t) = t * (t * initial_position - velocity) for every t at once:
# the (11, 1) column of times broadcasts against the 3-vectors -> shape (11, 3).
curve_points = time_values[:, np.newaxis] * (time_values[:, np.newaxis] * initial_position - velocity)

# Build and display the animation along the generated path
animation_figure = create_particle_animation(curve_points)
animation_figure.show()

Feel free to expand upon this tutorial by adding more complex elements, adjusting parameters, or incorporating these plots into your data analysis workflow. Happy plotting!