13
0
cdsc_reddit/clustering/lsi_base.py

45 lines
1.7 KiB
Python
Raw Normal View History

2021-05-10 20:46:49 +00:00
from clustering_base import clustering_job, clustering_result
from grid_sweep import grid_sweep, twoway_grid_sweep
2021-05-10 20:46:49 +00:00
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
class lsi_mixin:
    """Mixin that lets an object record an ``lsi_dims`` value."""

    def set_lsi_dims(self, lsi_dims) -> None:
        """Store ``lsi_dims`` on the instance as ``self.lsi_dims``.

        Args:
            lsi_dims: Value to remember; it is stored as-is, without
                validation or conversion.
        """
        self.lsi_dims = lsi_dims
@dataclass
class lsi_result_mixin:
    """Dataclass mixin that contributes a single ``lsi_dimensions`` field."""

    # Declared as an int; dataclasses do not enforce the type at runtime.
    lsi_dimensions: int
class lsi_grid_sweep(grid_sweep):
    """Grid sweep that builds one sub-sweep per LSI ``<dim>.feather`` file.

    Every selected file under ``inpath`` is passed to ``subsweep`` together
    with the integer parsed from its file stem, and the ``jobs`` of all
    resulting sub-sweeps are flattened into ``self.jobs``.

    NOTE(review): this constructor sets its attributes directly and does not
    call ``grid_sweep.__init__`` -- confirm that is intentional.
    """

    def __init__(self, jobtype, subsweep, inpath, lsi_dimensions, outpath, *args, **kwargs):
        """Create the sub-sweeps and collect their jobs.

        Args:
            jobtype: Stored on the instance as ``self.jobtype``.
            subsweep: Callable invoked once per input file as
                ``subsweep(lsi_path, outpath, lsi_dim, *args, **kwargs)``;
                each result must expose a ``jobs`` iterable.
            inpath: Directory containing the ``<dim>.feather`` files.
            lsi_dimensions: Either the string ``'all'`` (use every
                ``*.feather`` file in ``inpath``) or an iterable of
                dimensions, each mapped to ``inpath/<dim>.feather``.
            outpath: Forwarded unchanged to every ``subsweep`` call.
            *args: Extra positional arguments forwarded to ``subsweep``.
            **kwargs: Extra keyword arguments forwarded to ``subsweep``.

        Raises:
            ValueError: If a selected file's stem is not an integer.
        """
        self.jobtype = jobtype
        self.subsweep = subsweep
        inpath = Path(inpath)

        if lsi_dimensions == 'all':
            # glob() order depends on the filesystem; sort numerically by
            # dimension so sub-sweep and job order is reproducible.
            lsi_paths = sorted(inpath.glob("*.feather"), key=lambda p: int(p.stem))
        else:
            lsi_paths = [inpath / f"{dim}.feather" for dim in lsi_dimensions]

        # Print the resolved input files.
        print(lsi_paths)

        # The file stem is the dimensionality handed to the sub-sweep.
        lsi_nums = [int(p.stem) for p in lsi_paths]

        self.hasrun = False
        self.subgrids = [
            self.subsweep(lsi_path, outpath, lsi_dim, *args, **kwargs)
            for lsi_dim, lsi_path in zip(lsi_nums, lsi_paths)
        ]
        self.jobs = list(chain.from_iterable(gs.jobs for gs in self.subgrids))
class twoway_lsi_grid_sweep(twoway_grid_sweep):
    """Two-way grid sweep that builds one sub-sweep per LSI ``<dim>.feather`` file.

    Every selected file under ``inpath`` is passed to ``subsweep`` together
    with the integer parsed from its file stem and the two argument bundles
    ``args1`` / ``args2``; the ``jobs`` of all resulting sub-sweeps are
    flattened into ``self.jobs``.

    NOTE(review): this constructor sets its attributes directly and does not
    call ``twoway_grid_sweep.__init__`` -- confirm that is intentional.
    """

    def __init__(self, jobtype, subsweep, inpath, lsi_dimensions, outpath, args1, args2):
        """Create the sub-sweeps and collect their jobs.

        Args:
            jobtype: Stored on the instance as ``self.jobtype``.
            subsweep: Callable invoked once per input file as
                ``subsweep(lsi_path, outpath, lsi_dim, args1, args2)``;
                each result must expose a ``jobs`` iterable.
            inpath: Directory containing the ``<dim>.feather`` files.
            lsi_dimensions: Either the string ``'all'`` (use every
                ``*.feather`` file in ``inpath``) or an iterable of
                dimensions, each mapped to ``inpath/<dim>.feather``.
            outpath: Forwarded unchanged to every ``subsweep`` call.
            args1: Forwarded unchanged as the fourth positional argument.
            args2: Forwarded unchanged as the fifth positional argument.

        Raises:
            ValueError: If a selected file's stem is not an integer.
        """
        self.jobtype = jobtype
        self.subsweep = subsweep
        inpath = Path(inpath)

        if lsi_dimensions == 'all':
            # glob() order depends on the filesystem; sort numerically by
            # dimension so sub-sweep and job order is reproducible.
            lsi_paths = sorted(inpath.glob("*.feather"), key=lambda p: int(p.stem))
        else:
            lsi_paths = [inpath / f"{dim}.feather" for dim in lsi_dimensions]

        # The file stem is the dimensionality handed to the sub-sweep.
        lsi_nums = [int(p.stem) for p in lsi_paths]

        self.hasrun = False
        self.subgrids = [
            self.subsweep(lsi_path, outpath, lsi_dim, args1, args2)
            for lsi_dim, lsi_path in zip(lsi_nums, lsi_paths)
        ]
        self.jobs = list(chain.from_iterable(gs.jobs for gs in self.subgrids))