PyCharm母公司JetBrains又出可视化神器!

AI算法与图像处理

共 16853字,需浏览 34分钟

 ·

2022-07-13 17:52

本文简单介绍下偶遇的一个不错「python可视化工具lets-plot」,喜欢用R中的ggplot2绘制统计图的小伙伴一定要看看~

a、lets-plot由JetBrains(没错,「和PyCharm同出一家」)开发,主要「参考R语言中的ggplot2」「擅长统计图」,但多了「交互能力」,所以也是基于图层图形语法(the Grammar of Graphics),和之前介绍的plotnine一样,绘图原理可参考之前文章:

b、可在 「IntelliJ IDEA and PyCharm安装 lets-plot插件」 "Lets-Plot in SciView"。

lets-plot安装

pip install lets-plot
或者
conda install lets-plot

lets-plot快速绘图

import numpy as np
from lets_plot import *

LetsPlot.setup_html()  #默认开启Javascript交互模式

#LetsPlot.setup_html(no_js=True)#关闭Javascript交互模式

#数据准备
np.random.seed(12)
data = dict(cond=np.repeat(['A''B'], 200),
            rating=np.concatenate(
                (np.random.normal(01200), np.random.normal(11.5200))))

#绘图
ggplot(data, aes(x='rating', fill='cond')) + ggsize(500250) \
+ geom_density(color='dark_green', alpha=0.7) + scale_fill_brewer(type='seq') \
+ theme(axis_line_y='blank')

可以看出,语法几乎和ggplot2一行,但是,「比ggplot2多了交互」,上图为开启了交互模式,可以鼠标点击查看对应位置数据信息,LetsPlot.setup_html(no_js=True)控制交互模式的开启与关闭。

lets-plot详细介绍

lets-plot所有方法

import lets_plot

print(dir(lets_plot))#输出lets_plot的方法

['Dict', 'GGBunch', 'LetsPlot', 'NO_JS', 'OFFLINE', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__', '_global_settings', '_kbridge', '_settings', '_type_utils', '_version', 'aes', 'arrow', 'bistro', 'cfg', '「coord」_cartesian', 'coord_fixed', 'coord_map', 'element_blank', 'export', 'extend_path', 'facet_grid', 'facet_wrap', 'frontend_context', '「geom」_abline', 'geom_area', 'geom_bar', 'geom_bin2d', 'geom_boxplot', 'geom_contour', 'geom_contourf', 'geom_crossbar', 'geom_density', 'geom_density2d', 'geom_density2df', 'geom_errorbar', 'geom_freqpoly', 'geom_histogram', 'geom_hline', 'geom_image', 'geom_jitter', 'geom_line', 'geom_linerange', 'geom_livemap', 'geom_map', 'geom_path', 'geom_point', 'geom_pointrange', 'geom_polygon', 'geom_raster', 'geom_rect', 'geom_ribbon', 'geom_segment', 'geom_smooth', 'geom_step', 'geom_text', 'geom_tile', 'geom_vline', 'get_global_bool', 'gg_image_matrix', 'ggplot', 'ggsave', 'ggsize', 'ggtitle', 'guide_colorbar', 'guide_legend', 'guides', 'is_production', 'labs', 'layer', 'layer_tooltips', 'lims', 'mapping', 'maptiles_lets_plot', 'maptiles_zxy', 'plot', 'position_dodge', 'position_jitter', 'position_jitterdodge', 'position_nudge', '「sampling」_group_random', 'sampling_group_systematic', 'sampling_pick', 'sampling_random', 'sampling_random_stratified', 'sampling_systematic', 'sampling_vertex_dp', 'sampling_vertex_vw', '「scale」_alpha', 'scale_alpha_identity', 'scale_alpha_manual', 'scale_color_brewer', 'scale_color_continuous', 'scale_color_discrete', 'scale_color_gradient', 'scale_color_gradient2', 'scale_color_grey', 'scale_color_hue', 'scale_color_identity', 'scale_color_manual', 'scale_fill_brewer', 'scale_fill_continuous', 'scale_fill_discrete', 'scale_fill_gradient', 'scale_fill_gradient2', 'scale_fill_grey', 'scale_fill_hue', 'scale_fill_identity', 'scale_fill_manual', 'scale_linetype_identity', 'scale_linetype_manual', 'scale_shape', 'scale_shape_identity', 'scale_shape_manual', 'scale_size', 'scale_size_area', 'scale_size_identity', 'scale_size_manual', 'scale_x_continuous', 'scale_x_datetime', 'scale_x_discrete', 'scale_x_discrete_reversed', 'scale_x_log10', 'scale_x_reverse', 'scale_y_continuous', 'scale_y_datetime', 'scale_y_discrete', 'scale_y_discrete_reversed', 'scale_y_log10', 'scale_y_reverse', 'settings_utils', 'stat_corr', 'theme', 'xlab', 'xlim', 'ylab', 'ylim']

  • 「coord_开头」为坐标系设置方法;
  • 「geom_开头」的为lets-plot支持的图形类别;
  • 「sampling_开头」的为数据变换方法
  • 「sacle_开头」的为标度设置方法;

支持图形类别

结合geopandas绘制地图

相关性图(Correlation Plot)

import pandas as pd
from lets_plot import *
from lets_plot.bistro.corr import *

mpg_df = pd.read_csv('letplot.data//mpg.csv').drop(columns=['Unnamed: 0'])


def group(plots):
    """
    定义拼图函数group
    """

    bunch = GGBunch()
    for idx, p in enumerate(plots):
        x = (idx % 2) * 450
        y = int(idx / 2) * 350
        bunch.add_plot(p, x, y)

    return bunch


group([
    corr_plot(mpg_df).tiles().build() + ggtitle("Tiles"),
    corr_plot(mpg_df).points().build() + ggtitle("Points"),
    corr_plot(mpg_df).tiles().labels().build() + ggtitle("Tiles and labels"),
    corr_plot(mpg_df).points().labels().tiles().build() +
    ggtitle("Tiles, points and labels")
])
housing_df = pd.read_csv("letplot.data//Ames_house_prices_train.csv")

corr_plot(housing_df).tiles().palette_BrBG().build()

图片分面(Image Matrix)

import numpy as np
from lets_plot import *
from lets_plot.bistro.im import image_matrix

from PIL import Image
from io import BytesIO

image = Image.open('./dog.png')#一张旺财的png
img = np.asarray(image)

rows = 2
cols = 3
X = np.empty([rows, cols], dtype=object)
X.fill(img)

分布关系图(distribution plot)

geom_histogram, geom_density, geom_vline, geom_freqpoly, geom_boxplot, geom_histogram

#geom_histogram
from pandas import DataFrame
import numpy as np
from lets_plot import *

LetsPlot.setup_html()

np.random.seed(123)

#数据准备
data = DataFrame(
    dict(cond=np.repeat(['A''B'], 200),
         rating=np.concatenate(
             (np.random.normal(01200), np.random.normal(.81200)))))


#绘直方图
p = ggplot(data, aes(x='rating')) + ggsize(500250)
p + geom_histogram(binwidth=.5)
p + geom_histogram(
    aes(y='..density..'), binwidth=.5, colour="black", fill="white")

「geom_boxplot」

df = pd.read_csv('letplot.data//mpg.csv')

ggplot(df, aes('class''hwy')) + \
    geom_boxplot(tooltips=layer_tooltips().format('^Y''.2f')        # all positionals
                                          .format('^ymax''.3f')     # use number format --> "ymax: value"
                                          .format('^middle''{.3f}'# use line format --> "value"
                                          .format('^ymin''ymin is {.3f}')) + \
    theme(legend_position='none')

误差棒图(errorbar plot)

geom_errorbar, geom_line, geom_point, geom_bar, geom_crossbar, geom_linerange, geom_pointrange

pd = position_dodge(0.1)
data = dict(supp=['OJ''OJ''OJ''VC''VC''VC'],
            dose=[0.51.02.00.51.02.0],
            length=[13.2322.7026.067.9816.7726.14],
            len_min=[11.8321.224.504.2415.2623.35],
            len_max=[15.6324.927.1110.7219.2828.93])

p = ggplot(data, aes(x='dose', color='supp'))
p + geom_errorbar(aes(ymin='len_min', ymax='len_max', group='supp'), color='black', width=.1, position=pd) \
+ geom_line(aes(y='length'), position=pd) \
+ geom_point(aes(y='length'), position=pd, size=5)
p1 = p \
+ xlab("Dose (mg)") \
+ ylab("Tooth length (mm)") \
+ scale_color_manual(['orange''dark_green'], na_value='gray') \
+ ggsize(700400)

p1 \
+ geom_bar(aes(y='length', fill='supp'), stat='identity', position='dodge', color='black') \
+ geom_errorbar(aes(ymin='len_min', ymax='len_max', group='supp'), color='black', width=.1, position=position_dodge(0.9)) \
+ theme(legend_justification=[0,1], legend_position=[0,1])

散点图+趋势线(scatter/smooth plot)

geom_point, geom_smooth (stat_smooth)

mpg_df = pd.read_csv ("letplot.data/mpg.csv")
p = (ggplot(mpg_df, aes(x='displ', y='cty', fill='drv', size='hwy'))
     + scale_size(range=[515], breaks=[1540])
     + ggsize(700450)
    ) 

(p 
 + geom_point(shape=21, color='white',
              tooltips=layer_tooltips()
                          .anchor('top_right')
                          .min_width(180)
                          .format('cty''.0f')
                          .format('hwy''.0f')
                          .format('drv''{}wd')
                          .line('@manufacturer @model')
                          .line('cty/hwy [mpg]|@cty/@hwy')
                          .line('@|@class')
                          .line('drive train|@drv')
                          .line('@|@year')) 

关于右上角的显示设置(「tooltips」),更多见https://jetbrains.github.io/lets-plot-docs/pages/features/tooltips.html

import random

random.seed(123)
data = dict(cond=np.repeat(['A''B'], 10),
            xvar=[i + random.normalvariate(03for i in range(020)],
            yvar=[i + random.normalvariate(03for i in range(020)])

p = ggplot(data, aes(x='xvar', y='yvar')) + ggsize(600350)
p + geom_point(shape=1) + geom_smooth()
import pandas as pd

mpg_df = pd.read_csv('letplot.data/mpg.csv')
mpg_plot = ggplot(mpg_df, aes(x='displ', y='hwy'))
mpg_plot + geom_point(aes(color='drv'))\
         + geom_smooth(aes(color='drv'), method='loess', size=1)

density plot

geom_density2d, geom_density2df, geom_bin2d, geom_polygon, geom_point

cov0 = [[1-.8], [-.81]]
cov1 = [[1.8], [.81]]
cov2 = [[10.1], [.1.1]]

x0, y0 = np.random.multivariate_normal(mean=[-20], cov=cov0, size=400).T
x1, y1 = np.random.multivariate_normal(mean=[20], cov=cov1, size=400).T
x2, y2 = np.random.multivariate_normal(mean=[01], cov=cov2, size=400).T

data = dict(x=np.concatenate((x0, x1, x2)), y=np.concatenate((y0, y1, y2)))

p = ggplot(data, aes('x''y')) + ggsize(600300) + geom_point(color='black',
                                                                alpha=.1)

p + geom_density2d(aes(color='..level..')) \
+ scale_color_gradient(low='dark_green', high='yellow', guide=guide_colorbar(barheight=10, barwidth=300)) \
+ theme(legend_position='bottom')

分面图(facet plot)

data = pd.read_csv('letplot.data/mpg2.csv')
p = (ggplot(data, aes(x="engine horsepower", y="miles per gallon")) +
     geom_point(aes(color="origin of car")))
p + facet_wrap(facets=["origin of car""number of cylinders"], ncol=5)

拼图(GGBunch)

import math
import random
import numpy as np

n = 150
x_range = np.arange(-2 * math.pi, 2 * math.pi, 4 * math.pi / n)
y_range = np.sin(x_range) + np.array(
    [random.uniform(-.5.5for i in range(n)])
df = pd.DataFrame({'x': x_range, 'y': y_range})

p = ggplot(df, aes(x='x', y='y')) + geom_point(
    shape=21, fill='yellow', color='#8c564b')
p1 = p + geom_smooth(method='loess', size=1.5,
                     color='#d62728') + ggtitle('default (span = 0.5)')
p2 = p + geom_smooth(method='loess', span=.2, size=1.5,
                     color='#9467bd') + ggtitle('span = 0.2')
p3 = p + geom_smooth(method='loess', span=.7, size=1.5,
                     color='#1f77b4') + ggtitle('span = 0.7')
p4 = p + geom_smooth(method='loess', span=1, size=1.5,
                     color='#2ca02c') + ggtitle('span = 1')

bunch = GGBunch()
bunch.add_plot(p1, 00400300)
bunch.add_plot(p2, 4000400300)
bunch.add_plot(p3, 0300400300)
bunch.add_plot(p4, 400300400300)
bunch.show()

reference

更多功能,见github~

https://github.com/JetBrains/lets-plot

浏览 21
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报