add cluster selection to visualization
This commit is contained in:
		
							parent
							
								
									e6294b5b90
								
							
						
					
					
						commit
						dbe4c87f8b
					
				| @ -9,16 +9,38 @@ import fire | ||||
| import numpy as np | ||||
| 
 | ||||
def base_plot(plot_data):
    """Build the base text chart with cluster highlighting.

    Each row of ``plot_data`` is drawn as a text mark at (``x``, ``y``)
    showing its ``subreddit`` value.  A single-value selection on the
    ``cluster`` field is attached to the chart; it can be driven either by
    clicking a mark or by a dropdown listing every distinct cluster.  Marks
    matching the selection are colored by the nominal ``color`` field using
    the ``category10`` scheme, all others are drawn in light gray.

    Args:
        plot_data: Data passed to ``alt.Chart``.  It must expose a
            ``cluster`` attribute (read here to populate the dropdown) and
            provide the ``x``, ``y``, ``color``, ``subreddit`` and
            ``cluster`` fields referenced by the encodings.

    Returns:
        The Altair chart with the cluster selection added.
    """
    # Dropdown options are the distinct cluster ids rendered as strings.
    cluster_options = [str(c) for c in sorted(set(plot_data.cluster))]
    cluster_dropdown = alt.binding_select(options=cluster_options)

    # One selection serves both interactions: clicking a mark and picking a
    # cluster from the bound dropdown.  The blank name is kept as-is because
    # it is part of the rendered widget.
    cluster_click_select = alt.selection_single(
        on='click',
        fields=['cluster'],
        bind=cluster_dropdown,
        name=' ',
    )

    # Selected cluster keeps its assigned color; everything else is grayed.
    color = alt.condition(
        cluster_click_select,
        alt.Color(field='color', type='nominal', scale=alt.Scale(scheme='category10')),
        alt.value("lightgray"),
    )

    # Both axes share the same fixed domain.
    base = alt.Chart(plot_data).mark_text().encode(
        alt.X('x', axis=alt.Axis(grid=False), scale=alt.Scale(domain=(-65, 65))),
        alt.Y('y', axis=alt.Axis(grid=False), scale=alt.Scale(domain=(-65, 65))),
        color=color,
        text='subreddit',
    )

    return base.add_selection(cluster_click_select)
| 
 | ||||
| def zoom_plot(plot_data): | ||||
|     chart = base_plot(plot_data) | ||||
|     chart = chart.encode(alt.Color(field='color',type='nominal',scale=alt.Scale(scheme='category10'))) | ||||
| 
 | ||||
|     chart = chart.interactive() | ||||
|     chart = chart.properties(width=1275,height=1000) | ||||
| 
 | ||||
| @ -35,7 +57,7 @@ def viewport_plot(plot_data): | ||||
|         alt.X('x',axis=alt.Axis(grid=False)), | ||||
|         alt.Y('y',axis=alt.Axis(grid=False)), | ||||
|     ) | ||||
| 
 | ||||
|     | ||||
|     viewport = viewport.properties(width=600,height=400) | ||||
| 
 | ||||
|     viewport1 = viewport.add_selection(selector1) | ||||
| @ -52,7 +74,7 @@ def viewport_plot(plot_data): | ||||
|                      alt.Y('y',axis=alt.Axis(grid=False),scale=alt.Scale(domain=selectory2)) | ||||
|     ) | ||||
| 
 | ||||
|     sr = sr.encode(alt.Color(field='color',type='nominal',scale=alt.Scale(scheme='category10'))) | ||||
| 
 | ||||
|     sr = sr.properties(width=1275,height=600) | ||||
| 
 | ||||
| 
 | ||||
| @ -71,15 +93,29 @@ def assign_cluster_colors(tsne_data, clusters, n_colors, n_neighbors = 4): | ||||
|     distances = np.empty(shape=(centroids.shape[0],centroids.shape[0])) | ||||
| 
 | ||||
|     groups = tsne_data.groupby('cluster') | ||||
|     for centroid in centroids.itertuples(): | ||||
|         c_dists = groups.apply(lambda r: min(np.sqrt(np.square(centroid.x - r.x) + np.square(centroid.y-r.y)))) | ||||
|         distances[:,centroid.Index] = c_dists | ||||
|      | ||||
|     points = np.array(tsne_data.loc[:,['x','y']]) | ||||
|     centers = np.array(centroids.loc[:,['x','y']]) | ||||
| 
 | ||||
|     # point x centroid | ||||
|     point_center_distances = np.linalg.norm((points[:,None,:] - centers[None,:,:]),axis=-1) | ||||
|      | ||||
|     # distances is cluster x centroid: min distance from any point in the cluster to each centroid | ||||
|     for gid, group in groups: | ||||
|         c_dists = point_center_distances[group.index.values,:].min(axis=0) | ||||
|         distances[group.cluster.values[0],] = c_dists         | ||||
| 
 | ||||
|     # nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(centroids)  | ||||
|     # distances, indices = nbrs.kneighbors() | ||||
| 
 | ||||
|     nbrs = NearestNeighbors(n_neighbors=n_neighbors,metric='precomputed').fit(distances)  | ||||
|     distances, indices = nbrs.kneighbors() | ||||
|     nearest = distances.argpartition(n_neighbors,0) | ||||
|     indices = nearest[:n_neighbors,:].T | ||||
|     # neighbor_distances = np.copy(distances) | ||||
|     # neighbor_distances.sort(0) | ||||
|     # neighbor_distances = neighbor_distances[0:n_neighbors,:] | ||||
|      | ||||
|     # nbrs = NearestNeighbors(n_neighbors=n_neighbors,metric='precomputed').fit(distances)  | ||||
|     # distances, indices = nbrs.kneighbors() | ||||
| 
 | ||||
|     color_assignments = np.repeat(-1,len(centroids)) | ||||
| 
 | ||||
| @ -119,13 +155,13 @@ def build_visualization(tsne_data, clusters, output): | ||||
| if __name__ == "__main__": | ||||
|     fire.Fire(build_visualization) | ||||
| 
 | ||||
| # commenter_data = pd.read_feather("tsne_author_fit.feather") | ||||
| # clusters = pd.read_feather('author_3000_clusters.feather') | ||||
| # commenter_data = assign_cluster_colors(commenter_data,clusters,10,8) | ||||
| # commenter_zoom_plot = zoom_plot(commenter_data) | ||||
| # commenter_viewport_plot = viewport_plot(commenter_data) | ||||
| # commenter_zoom_plot.save("subreddit_commenters_tsne_3000.html") | ||||
| # commenter_viewport_plot.save("subreddit_commenters_tsne_3000_viewport.html") | ||||
| commenter_data = pd.read_feather("tsne_author_fit.feather") | ||||
| clusters = pd.read_feather('author_3000_clusters.feather') | ||||
| commenter_data = assign_cluster_colors(commenter_data,clusters,10,8) | ||||
| commenter_zoom_plot = zoom_plot(commenter_data) | ||||
| commenter_viewport_plot = viewport_plot(commenter_data) | ||||
| commenter_zoom_plot.save("subreddit_commenters_tsne_3000.html") | ||||
| commenter_viewport_plot.save("subreddit_commenters_tsne_3000_viewport.html") | ||||
| 
 | ||||
| # chart = chart.properties(width=10000,height=10000) | ||||
| # chart.save("test_tsne_whole.svg") | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user