Skip to content

multidim_plotter

plot_from_sql(x_tag, y_tag, output, label, exp_id=None)

Plot colormap/3D figure from data in `<experiment_id>/results.db`.

Parameters:

Name Type Description Default
x_tag str

Tag to use as x axis.

required
y_tag str

Tag to use as y axis.

required
output str

String to use as output, needs to correspond to one of the output cols in the db.

required
label str

Label for the figure; it is used as the plot title.

required
exp_id str

Optional experiment id. If omitted, 'latest_experiment' is used.

None
Source code in emod_api/multidim_plotter.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def plot_from_sql(x_tag: str,
                  y_tag: str,
                  output: str,
                  label: str,
                  exp_id: str = None):
    """
    Plot colormap/3D figure from data in <experiment_id>/results.db.

    Rows of the ``results`` table are grouped by (x_tag, y_tag) and ``output`` is
    averaged within each group. The averages are drawn as a triangulated surface
    that is viewed from directly above (elev=90), so it reads as a colormap.

    Args:
        x_tag: Tag to use as x axis. Spaces and dashes are replaced with underscores
            to form the column name.
        y_tag: Tag to use as y axis. Spaces and dashes are replaced with underscores
            to form the column name.
        output: String to use as output, needs to correspond to one of the output cols in the db.
        label: Figure needs a label; it is used as the plot title.
        exp_id: Optional experiment id. If omitted, 'latest_experiment' is used.

    Returns:
        None. If the query fails, or returns fewer than the three points needed to
        triangulate a surface, a message is printed and no figure is created.
    """

    db = os.path.join( str(exp_id) if exp_id else "latest_experiment", "results.db" )
    x_tag = x_tag.replace( ' ', '_' ).replace( '-', '_' )
    y_tag = y_tag.replace( ' ', '_' ).replace( '-', '_' )

    # NOTE(security): column names cannot be bound as SQL parameters, so the tags are
    # interpolated directly into the statement. Only pass trusted tag/output names.
    query = f"select {x_tag}, {y_tag}, avg({output}) from results group by {x_tag}, {y_tag};"

    con = sqlite3.connect( db )
    try:
        cur = con.cursor()
        cur.execute( query )
        results = cur.fetchall()
    except Exception as ex:
        print( f"Encountered fatal exception {ex} when executing query {query} on db {db}." )
        return
    finally:
        # Always release the database handle, whether or not the query succeeded.
        con.close()

    # plot_trisurf triangulates the (x, y) points and needs at least three of them.
    if len( results ) < 3:
        print( f"Query {query} on db {db} returned {len(results)} row(s); at least 3 are needed to plot a surface." )
        return

    x, y, z = zip( *results )

    # Create the figure only once there is data to draw, so failures above do not
    # leave an empty figure open. add_subplot attaches the 3D axes to the figure;
    # a bare Axes3D(fig) is no longer added automatically by recent matplotlib.
    fig = plt.figure()
    ax = fig.add_subplot( projection='3d' )

    surf = ax.plot_trisurf(x, y, z, cmap=cm.jet, linewidth=0.1)
    ax.set_xlabel( f"{x_tag} rate" )
    ax.set_ylabel( f"{y_tag} rate" )
    fig.colorbar(surf, shrink=0.5, aspect=5)
    # Look straight down the z axis so the surface appears as a 2D colormap.
    ax.view_init(elev=90, azim=0)
    plt.title( label )
    plt.show()