Skip to content

Commit 85de293

Browse files
authored
added files
1 parent 5e8f2e1 commit 85de293

File tree

17 files changed

+483
-0
lines changed

17 files changed

+483
-0
lines changed

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) Subhadarshi Panda
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

MANIFEST.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Things to include in the built package (besides the packages defined in setup.py)
2+
include README.md
3+
include LICENSE
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import numpy as np
2+
import torch
3+
from tqdm import tqdm
4+
5+
6+
def initialize(X, num_clusters):
    """Pick the starting cluster centers by sampling rows of ``X``.

    ``num_clusters`` distinct row indices are drawn with NumPy's global
    random generator (sampling without replacement) and the matching rows
    of ``X`` are returned.

    :param X: indexable collection of samples; must support ``len`` and
        indexing with an integer index array
    :param num_clusters: how many rows to draw
    :return: the selected rows, ``X[indices]``
    """
    chosen_rows = np.random.choice(len(X), num_clusters, replace=False)
    return X[chosen_rows]
11+
12+
13+
def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        tol=1e-4,
        device=torch.device('cpu')
):
    """Cluster the rows of ``X`` with k-means.

    :param X: tensor of samples, one sample per row; it is converted to
        float and moved to ``device``
    :param num_clusters: number of clusters
    :param distance: ``'euclidean'`` or ``'cosine'``
    :param tol: iteration stops once the squared total center shift drops
        below this value
    :param device: device on which the computation runs
    :return: ``(cluster_ids, cluster_centers)``, both moved to the CPU
    :raises NotImplementedError: if ``distance`` is not supported
    """
    print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = pairwise_distance
    elif distance == 'cosine':
        pairwise_distance_function = pairwise_cosine
    else:
        raise NotImplementedError(f'unsupported distance: {distance!r}')

    # convert to float and transfer to device
    X = X.float()
    X = X.to(device)

    # initialize
    initial_state = initialize(X, num_clusters)

    iteration = 0
    tqdm_meter = tqdm(desc='[running kmeans]')
    try:
        while True:
            # pass the device explicitly: the helpers default to the CPU,
            # which would silently move the heavy computation off `device`
            dis = pairwise_distance_function(X, initial_state, device=device)
            # force an (N, K) matrix even if the helper dropped singleton
            # dimensions (single sample or single cluster)
            dis = dis.reshape(X.shape[0], initial_state.shape[0])

            choice_cluster = torch.argmin(dis, dim=1)

            initial_state_pre = initial_state.clone()

            for index in range(num_clusters):
                members = X[choice_cluster == index]
                if members.shape[0] == 0:
                    # an empty cluster has no mean (it would become NaN);
                    # keep its previous center instead
                    continue
                initial_state[index] = members.mean(dim=0)

            # sum over clusters of the Euclidean distance each center moved
            center_shift = torch.sum(
                torch.sqrt(
                    torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
                ))

            # increment iteration
            iteration = iteration + 1

            # update tqdm meter
            tqdm_meter.set_postfix(
                iteration=f'{iteration}',
                center_shift=f'{center_shift ** 2:0.6f}',
                tol=f'{tol:0.6f}'
            )
            tqdm_meter.update()
            if center_shift ** 2 < tol:
                break
    finally:
        # release the progress meter even if an iteration raises
        tqdm_meter.close()

    return choice_cluster.cpu(), initial_state.cpu()
72+
73+
74+
def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu')
):
    """Assign each row of ``X`` to its nearest cluster center.

    :param X: tensor of samples, one sample per row; it is converted to
        float and moved to ``device``
    :param cluster_centers: tensor of cluster centers, one center per row
    :param distance: ``'euclidean'`` or ``'cosine'``
    :param device: device on which the computation runs
    :return: index of the nearest center for every sample, moved to the CPU
    :raises NotImplementedError: if ``distance`` is not supported
    """
    print(f'predicting on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = pairwise_distance
    elif distance == 'cosine':
        pairwise_distance_function = pairwise_cosine
    else:
        raise NotImplementedError(f'unsupported distance: {distance!r}')

    # convert to float and transfer to device
    X = X.float()
    X = X.to(device)

    # pass the device explicitly: the helpers default to the CPU, which
    # would silently move the computation off `device`
    dis = pairwise_distance_function(X, cluster_centers, device=device)
    # force an (N, K) matrix even if the helper dropped singleton
    # dimensions (single sample or single center)
    dis = dis.reshape(X.shape[0], cluster_centers.shape[0])
    choice_cluster = torch.argmin(dis, dim=1)

    return choice_cluster.cpu()
99+
100+
101+
def pairwise_distance(data1, data2, device=torch.device('cpu')):
    """Return the squared Euclidean distance between every pair of rows.

    :param data1: tensor with N rows of M features
    :param data2: tensor with K rows of M features
    :param device: device both inputs are moved to before computing
    :return: N*K tensor whose ``[i, j]`` entry is the squared Euclidean
        distance between ``data1[i]`` and ``data2[j]``
    """
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*K*M
    B = data2.unsqueeze(dim=0)

    # broadcast to N*K*M, then reduce the feature axis. No squeeze here:
    # it would drop the N or K axis whenever either equals 1 and break
    # callers that reduce over dim=1.
    dis = (A - B) ** 2.0
    return dis.sum(dim=-1)
115+
116+
117+
def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    """Return the cosine distance between every pair of rows.

    :param data1: tensor with N rows of M features
    :param data2: tensor with K rows of M features
    :param device: device both inputs are moved to before computing
    :return: N*K tensor whose ``[i, j]`` entry is
        ``1 - cos(data1[i], data2[j])``
    """
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*K*M
    B = data2.unsqueeze(dim=0)

    # normalize the points | [0.3, 0.4] -> [0.3/0.5, 0.4/0.5]
    # the norm is clamped away from zero so an all-zero row yields a
    # finite distance (1) instead of NaN from a 0/0 division
    eps = 1e-12
    A_normalized = A / A.norm(dim=-1, keepdim=True).clamp(min=eps)
    B_normalized = B / B.norm(dim=-1, keepdim=True).clamp(min=eps)

    cosine = A_normalized * B_normalized

    # reduce the feature axis to get the N*K matrix. No squeeze here: it
    # would drop the N or K axis whenever either equals 1.
    cosine_dis = 1 - cosine.sum(dim=-1)
    return cosine_dis

build/lib/kmeans_pytorch/main.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
4+
__version__ = "0.2"
5+
6+
7+
def main():
    """Console entry point; for now it only prints a placeholder."""
    placeholder = "TODO"
    print(placeholder)


if __name__ == "__main__":
    main()
4.09 KB
Binary file not shown.

dist/kmeans_pytorch-0.2.tar.gz

3.98 KB
Binary file not shown.

kmeans_pytorch.egg-info/PKG-INFO

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
Metadata-Version: 2.1
2+
Name: kmeans-pytorch
3+
Version: 0.2
4+
Summary: UNKNOWN
5+
Home-page: https://github.com/subhadarship/kmeans_pytorch
6+
Author: Subhadarshi Panda
7+
Author-email: subhadarshipanda08@gmail.com
8+
License: License :: OSI Approved :: MIT License
9+
Description: # K Means using PyTorch
10+
PyTorch implementation of kmeans for utilizing GPU
11+
12+
# Getting Started
13+
```
14+
15+
import torch
16+
import numpy as np
17+
from kmeans_pytorch import kmeans
18+
19+
# data
20+
data_size, dims, num_clusters = 1000, 2, 3
21+
x = np.random.randn(data_size, dims) / 6
22+
x = torch.from_numpy(x)
23+
24+
# kmeans
25+
cluster_ids_x, cluster_centers = kmeans(
26+
X=x, num_clusters=num_clusters, distance='euclidean', device=torch.device('cuda:0')
27+
)
28+
```
29+
30+
see [`example.ipynb`](https://github.com/subhadarship/kmeans_pytorch/blob/master/example.ipynb) for a more elaborate example
31+
32+
# Requirements
33+
* [PyTorch](http://pytorch.org/) version >= 1.0.0
34+
* Python version >= 3.6
35+
36+
# Installation
37+
38+
install with `pip`:
39+
```
40+
pip install kmeans-pytorch
41+
```
42+
43+
**Installing from source**
44+
45+
To install from source and develop locally:
46+
```
47+
git clone https://github.com/subhadarship/kmeans_pytorch
48+
cd kmeans_pytorch
49+
pip install --editable .
50+
```
51+
52+
# Notes
53+
- useful when clustering a large number of samples
54+
- utilizes GPU for faster matrix computations
55+
- supports Euclidean and cosine distances (for now)
56+
57+
Platform: UNKNOWN
58+
Classifier: Programming Language :: Python
59+
Requires-Python: >=3.6
60+
Description-Content-Type: text/markdown
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
LICENSE
2+
MANIFEST.in
3+
README.md
4+
setup.py
5+
kmeans_pytorch/__init__.py
6+
kmeans_pytorch/main.py
7+
kmeans_pytorch.egg-info/PKG-INFO
8+
kmeans_pytorch.egg-info/SOURCES.txt
9+
kmeans_pytorch.egg-info/dependency_links.txt
10+
kmeans_pytorch.egg-info/entry_points.txt
11+
kmeans_pytorch.egg-info/not-zip-safe
12+
kmeans_pytorch.egg-info/top_level.txt
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[console_scripts]
2+
kmeans_pytorch = kmeans_pytorch.main:main
3+

0 commit comments

Comments
 (0)