In [1]:
%matplotlib inline
import os
import sys

import astropy as ast
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
In [2]:
# progress meter for big loops
def progress_meter(progress):
    """Rewrite the current output line with a loading percentage.

    `progress` is shown with one decimal place followed by a percent sign.
    It is expected to be a percentage running from 0 to 100; the value is
    not range-checked here.
    """
    # the leading carriage return sends the cursor back to the start of the
    # line, so successive calls overwrite each other instead of scrolling
    message = "\rloading... %.1f%%" % progress
    out = sys.stdout
    out.write(message)
    out.flush()
In [47]:
# 100 random points, uniform over a 100 x 100 square
# seed the global generator so a fresh-kernel run reproduces the same points
np.random.seed(42)
n_points = 100
test_x = np.random.rand(n_points) * 100
test_y = np.random.rand(n_points) * 100
plt.scatter(test_x, test_y)
plt.xlabel('x')
plt.ylabel('y')
plt.title('random test points')
plt.show()
In [48]:
# get distances (cartesian) for every pair of points in one broadcasted step:
# dist_matrix[j, i] is the distance between point j and point i, so column i
# (named str(i) below) holds the distance from every point to point i
n_src = len(test_x)
dx = test_x[:, np.newaxis] - test_x[np.newaxis, :]
dy = test_y[:, np.newaxis] - test_y[np.newaxis, :]
dist_matrix = np.sqrt(dx**2 + dy**2)
dist_cols = pd.DataFrame(dist_matrix, columns=[str(i) for i in range(n_src)])

# catalog labels by row position: rows 0-39 are 'A', rows 40-74 are 'B',
# everything from row 75 on is 'C'
catname = np.full(n_src, 'C', dtype=object)
catname[:40] = 'A'
catname[40:75] = 'B'

# mimic same columns as actual df
meta_order = ['oncID', 'oncflag', 'catname', 'catID', 'x', 'y']
meta_cols = pd.DataFrame({
    'oncID': np.nan,   # filled in by the grouping step
    'oncflag': '',     # grouping step appends its flag text here
    'catname': catname,
    'catID': np.random.randint(1, 1000, n_src),
    'x': test_x,
    'y': test_y,
}, columns=meta_order)

# metadata columns first, then one distance column per point
onc_gs = pd.concat([meta_cols, dist_cols], axis=1)
In [49]:
# new source numbering starts at highest A number + 1
new_source = max(onc_gs.loc[onc_gs['catname'] == 'A', 'catID'].values) + 1
dist_crit = 7


def _close_labels(frame, label, limit):
    """Return the set of row labels whose distance to row `label` is below `limit`.

    Distances are read from the column named str(label).
    """
    is_close = frame[str(label)] < limit
    return set(is_close[is_close].index.tolist())


# ====
# Group rows that are chained together by distances below dist_crit, then give
# every row of a group the same oncID and a 'd<group size>' flag.
# Row k is looked up through column str(k), so row labels are taken to be
# 0 .. len(onc_gs) - 1.
n_rows = len(onc_gs)
exclude = set()  # rows that already belong to a finished group
for k in range(n_rows):
    if k not in exclude:
        # keep adding match values until no new values are added; each member's
        # neighbours are looked up once, when it first enters the frontier
        group = set()
        frontier = set([k])
        while frontier:
            group |= frontier
            reached = set()
            for label in frontier:
                reached |= _close_labels(onc_gs, label, dist_crit)
            frontier = reached - group
        exclude |= group

        # .loc does not accept a set as an indexer, so use a sorted list
        members = sorted(group)
        match = onc_gs.loc[members, ['catname', 'catID']]

        # every group is flagged with its size, including single-row groups
        # NOTE: this appends to the existing flag text, so re-running this cell
        # without rebuilding onc_gs will repeat the flag
        onc_gs.loc[members, 'oncflag'] += 'd' + str(len(members))

        # use A number if it exists -- if multiple, use lowest
        a_ids = match.loc[match['catname'] == 'A', 'catID']
        if len(a_ids) > 0:
            onc_gs.loc[members, 'oncID'] = a_ids.min()
        # otherwise give it a new number
        else:
            onc_gs.loc[members, 'oncID'] = new_source
            new_source += 1
    # k + 1 so the meter finishes on 100.0% after the last row
    progress_meter((k + 1) * 100. / n_rows)

# preview the identifying columns only; the full frame also carries one
# distance column per row
onc_gs[['oncID', 'oncflag', 'catname', 'catID', 'x', 'y']].head(10)
Out[49]:
In [50]:
# number of rows carrying each group-size flag
# (call form of print works on both Python 2 and Python 3)
print(onc_gs['oncflag'].value_counts())
In [51]:
# marker colour for each group-size flag found in onc_gs['oncflag']
grp_clr = {
    'd1': 'black',
    'd2': 'pink',
    'd3': 'red',
    'd4': 'orange',
    'd5': 'green',
    'd6': 'blue',
    'd7': 'purple',
    'd8': 'tan',
    'd9': 'brown',
    'd10': 'grey',
}
# any flag without an entry above is drawn in this colour rather than
# stopping the plot with a KeyError
grp_clr_other = 'cyan'

# marker colour for each catalog name
cat_clr = {
    'A': 'red',
    'B': 'green',
    'C': 'blue',
}

# one colour per plotted point, looked up by row label 0 .. len(test_x) - 1
point_labels = list(range(len(test_x)))
cat_colors = [cat_clr[name] for name in onc_gs.loc[point_labels, 'catname']]
grp_colors = [grp_clr.get(flag, grp_clr_other)
              for flag in onc_gs.loc[point_labels, 'oncflag']]

f, (ax1, ax2) = plt.subplots(2, 1, figsize=(10,20))

# top panel: points coloured by catalog, annotated with their row number
ax1.scatter(test_x, test_y, c=cat_colors)
for z in point_labels:
    ax1.text(test_x[z], test_y[z], s=str(z))
ax1.set_title('points by catalog')

# bottom panel: same points coloured by the size of the group they fell into
ax2.scatter(test_x, test_y, c=grp_colors)
ax2.set_title('points by group size')

# equal scaling so distances look the same along x and y
ax1.axis('equal')
ax2.axis('equal')
plt.show()
In [52]:
# save the grouped table as check_group.csv in the Documents folder of
# whichever account runs the notebook, instead of one fixed user's directory
check_path = os.path.join(os.path.expanduser('~'), 'Documents', 'check_group.csv')
onc_gs.to_csv(check_path)
In [ ]: