from math import log
from bisect import bisect_left
class MemoryModelMissed(Exception):
    """Raised when no memory model can be selected for a row.

    ``MemoryModelController`` raises it with the row's type as the single
    argument. NOTE(review): unlike the other errors in this module it derives
    from ``Exception`` directly, not from ``MemoryModelError``.
    """
class MemoryModelError(Exception):
    """Base class for errors raised while configuring or evaluating memory models."""
class ZeroCacheSizeNotSupported(MemoryModelError):
    """Raised when a cache size that must be positive is zero (or not positive).

    The offending list of sizes is passed as the exception argument.
    """
class MemoryTypeMissed(MemoryModelError):
    """Raised when a memory type has no configured target cache sizes.

    The missing memory type is passed as the exception argument.
    """
class MemoryModelController:
    """Dispatch traffic-estimation calls to the memory model registered for a row.

    Models are registered either for a specific row type (matched with
    ``isinstance``, checked in registration order) or as a default that is
    used when no typed model matches.
    """

    def __init__(self):
        # None means "no default model registered".
        self._default_model = None
        self._typed_models = {}

    def add_memory_model(self, model, row_type=None):
        """Register *model* for *row_type*, or as the default when *row_type* is None.

        Registering again for the same key replaces the previous model.
        """
        if row_type is None:
            self._default_model = model
        else:
            self._typed_models[row_type] = model

    def _get_model(self, row):
        """Return the first typed model whose type matches *row*, else the default.

        Raises:
            MemoryModelMissed: If no typed model matches and no default model
                has been registered.
        """
        for row_type, model in self._typed_models.items():
            if isinstance(row, row_type):
                return model
        # Compare with None explicitly: a registered model object that is
        # falsy (e.g. defines __len__ or __bool__) must still be used.
        if self._default_model is not None:
            return self._default_model
        raise MemoryModelMissed(type(row))

    def estimate_traffic(self, row):
        """Delegate to the selected model's ``estimate_traffic``."""
        return self._get_model(row).estimate_traffic(row)

    def get_carm_traffic(self, row):
        """Delegate to the selected model's ``get_carm_traffic``."""
        return self._get_model(row).get_carm_traffic(row)
class BaseMemoryTrafficEstimator:
    """Shared configuration and CARM-traffic helper for the traffic estimators.

    ``target_cache_structure`` maps each memory type to a sequence of cache
    levels; a level's size is the product of its first two entries. For each
    memory type the per-level sizes are stored in ``target_cache_sizes`` and
    the output of ``get_accumulated_cache_sizes`` (given ``dummy_cache_size``)
    in ``_accumulated_target_cache_sizes``.
    """

    def __init__(self, target_cache_structure, dummy_cache_size=0):
        self.target_cache_sizes = {}
        self._accumulated_target_cache_sizes = {}
        for mem_type, levels in target_cache_structure.items():
            per_level = [level[0] * level[1] for level in levels]
            self.target_cache_sizes[mem_type] = per_level
            self._accumulated_target_cache_sizes[mem_type] = (
                get_accumulated_cache_sizes(per_level, dummy_cache_size)
            )

    def get_carm_traffic(self, row):
        """Return CARM traffic per memory type for *row*.

        For every configured memory type whose ``row.traffic`` entry is truthy,
        the result holds one element per accumulated cache size: the entry's
        ``carm`` value first, then ``(0, 0)`` for each remaining size. Memory
        types without traffic are omitted.
        """
        carm_by_type = {}
        for mem_type, sizes in self._accumulated_target_cache_sizes.items():
            traffic = row.traffic.get(mem_type)
            if traffic:
                carm_by_type[mem_type] = [
                    traffic.carm if position == 0 else (0, 0)
                    for position in range(len(sizes))
                ]
        return carm_by_type
def get_accumulated_cache_sizes(cache_sizes, dummy_cache_size):
    """Return the running totals of *cache_sizes*.

    When *dummy_cache_size* is truthy it is prepended and included in every
    subsequent total, e.g. ``([1, 2], 4) -> [4, 5, 7]``; otherwise
    ``([1, 2, 3], 0) -> [1, 3, 6]``.

    Raises:
        ZeroCacheSizeNotSupported: If the first total is not positive, or if
            there is nothing to accumulate (no levels and no dummy size).
    """
    res = [dummy_cache_size] if dummy_cache_size else []
    for lvl_size in cache_sizes:
        res.append(lvl_size + (res[-1] if res else 0))
    # An empty hierarchy has no usable size; report it like a zero-size
    # cache instead of failing with a bare IndexError on res[0].
    if not res or res[0] <= 0:
        raise ZeroCacheSizeNotSupported(res)
    return res
def get_scaling_coefficients(baseline_cache_sizes, baseline_cache_misses, target_cache_sizes):
    """Estimate relative miss-scaling factors for *target_cache_sizes*.

    A piecewise miss-vs-size curve is fitted through the baseline points
    ``(baseline_cache_sizes[i], baseline_cache_misses[i])`` (paired with
    ``zip``, so the shorter input bounds the number of points). It is then
    evaluated at every target size. Each piece is
    ``misses * (size / base_size) ** power + linear * (size - base_size)``,
    and evaluated values are clamped to be non-negative:

    * a single baseline point uses a power law with exponent ``-0.5``;
    * two points of equal size give a constant piece;
    * when the right point has zero misses, or more misses than the left
      one, a straight line through both points is used;
    * otherwise a power law through both points is fitted, with its
      exponent clamped to ``<= 0``;
    * fitting stops at the first point with zero misses. If the last piece
      is linear (with non-zero misses), a flat piece anchored at that
      piece's right endpoint caps the tail.

    Args:
        baseline_cache_sizes: Baseline cache sizes; expected to be sorted
            ascending because they are searched with ``bisect_left``.
        baseline_cache_misses: Miss values measured at those sizes.
        target_cache_sizes: Sizes to evaluate the curve at; must be non-zero.

    Returns:
        One factor per target size, divided by the first factor so that it
        becomes 1. ``[1]`` is returned when *target_cache_sizes* is empty. If
        the first raw factor is 0, it is replaced by 1 and the remaining raw
        factors are returned as they are.

    Raises:
        ZeroCacheSizeNotSupported: If any target cache size is zero.
    """
    baseline_misses_n_sizes = list(zip(baseline_cache_misses, baseline_cache_sizes))
    # Piece i is (misses_i, size_i, power, linear) and spans points i and i + 1.
    if len(baseline_misses_n_sizes) == 1:
        baseline_scaling_coeffs = [baseline_misses_n_sizes[0] + (-0.5, 0)]
    else:
        baseline_scaling_coeffs = []
        for l_idx, (l_misses, l_size) in enumerate(baseline_misses_n_sizes[:-1]):
            r_misses, r_size = baseline_misses_n_sizes[l_idx + 1]
            if l_misses == 0:
                break
            if r_size == l_size:
                power = 0
                linear = 0
            elif r_misses == 0 or r_misses > l_misses:
                power = 0
                linear = (r_misses - l_misses) / (r_size - l_size)
            else:
                power = min(0, log(l_misses / r_misses) / log(l_size / r_size))
                linear = 0
            baseline_scaling_coeffs.append((l_misses, l_size, power, linear))
        if not baseline_scaling_coeffs:
            baseline_scaling_coeffs = [(1, 1, -0.5, 0)]
        else:
            last_misses, _, _, last_linear = baseline_scaling_coeffs[-1]
            if last_misses and last_linear:
                # The right endpoint of the last fitted piece is point
                # len(coeffs). The loop's r_misses/r_size cannot be used here:
                # after an early break they describe the point after it.
                tail_misses, tail_size = baseline_misses_n_sizes[len(baseline_scaling_coeffs)]
                baseline_scaling_coeffs.append((tail_misses, tail_size, 0.0, 0))

    target_scaling_coeffs = []
    for lvl_size in target_cache_sizes:
        if lvl_size == 0:
            raise ZeroCacheSizeNotSupported(target_cache_sizes)
        # Use the piece whose left endpoint is the largest of the first
        # len(coeffs) baseline sizes below lvl_size (piece 0 if none is).
        r_idx = bisect_left(baseline_cache_sizes, lvl_size, hi=len(baseline_scaling_coeffs))
        idx = max(0, r_idx - 1)
        baseline_misses, baseline_size, baseline_pow, baseline_lin = baseline_scaling_coeffs[idx]
        target_scaling_coeffs.append(
            max(
                0,
                baseline_misses * pow(lvl_size / baseline_size, baseline_pow)
                + baseline_lin * (lvl_size - baseline_size),
            )
        )
    if not target_scaling_coeffs:
        return [1]
    if target_scaling_coeffs[0] == 0:
        # Avoid dividing by zero; the raw factors are kept instead.
        target_scaling_coeffs[0] = 1
    return [sc / target_scaling_coeffs[0] for sc in target_scaling_coeffs]
class SimpleMemoryTrafficEstimator(BaseMemoryTrafficEstimator):
    """Estimate target-cache traffic by rescaling a row's baseline traffic.

    ``get_baseline_traffic_n_sizes`` is a hook taking a row and returning, per
    memory type, a ``(cache_sizes, traffic_per_level, footprint)`` triple.
    ``get_measured_traffic`` is a second hook that this class does not call.
    Both default to returning an empty dict; presumably callers overwrite
    them (TODO confirm).
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.get_measured_traffic = lambda x: {}
        self.get_baseline_traffic_n_sizes = lambda row: {}

    def estimate_traffic(self, row):
        """Return estimated traffic per memory type for the target caches.

        For each memory type the result is a list with one tuple per target
        accumulated cache size. Each tuple element is
        ``(t0 - f) * scale + f``, where ``t0`` is the element of the first
        baseline level, ``f`` the matching footprint element, and ``scale``
        comes from ``get_scaling_coefficients``. Memory types with empty
        baseline traffic map to ``[]``.

        Raises:
            MemoryTypeMissed: If a memory type has no target cache sizes.
        """
        target_traffic = {}
        for mem_type, bl_mem_traffic_n_sizes in self.get_baseline_traffic_n_sizes(row).items():
            cache_sizes_bl, row_traffic_bl, footprint = bl_mem_traffic_n_sizes
            try:
                cache_sizes_tgt = self._accumulated_target_cache_sizes[mem_type]
            except KeyError:
                # The KeyError adds nothing beyond mem_type; don't chain it.
                raise MemoryTypeMissed(mem_type) from None
            if not row_traffic_bl:
                target_traffic[mem_type] = []
                continue
            if cache_sizes_bl:
                # Shift all target sizes by a constant so the first target
                # size equals the first baseline size.
                delta = cache_sizes_tgt[0] - cache_sizes_bl[0]
                cache_sizes_tgt = [x - delta for x in cache_sizes_tgt]
            else:
                # No baseline sizes: use the first target size as the only one.
                cache_sizes_bl = cache_sizes_tgt[:1]
            # Element-wise traffic above the footprint, one tuple per level.
            cache_misses_bl = [
                tuple(t - f for t, f in zip(lvl_traffic, footprint))
                for lvl_traffic in row_traffic_bl
            ]
            # zip(*...) regroups per-level tuples into per-element series; each
            # series yields one scaling factor per target size.
            scale_coefficients = [
                get_scaling_coefficients(cache_sizes_bl, misses, cache_sizes_tgt)
                for misses in zip(*cache_misses_bl)
            ]
            # Regroup back into per-target-level factors and rescale the first
            # baseline level's above-footprint traffic.
            target_traffic[mem_type] = [
                tuple((t - f) * s + f for t, f, s in zip(row_traffic_bl[0], footprint, sl))
                for sl in zip(*scale_coefficients)
            ]
        return target_traffic
class FlexMemoryTrafficEstimator(BaseMemoryTrafficEstimator):
    """Traffic estimator that delegates the estimate to the row itself.

    The row's ``get_flex_mem_traffic`` method receives the accumulated target
    cache sizes per memory type.
    """

    def __init__(self, target_memory_config, dummy_cache_size=0):
        # The default matches BaseMemoryTrafficEstimator, so callers may omit it.
        super().__init__(target_memory_config, dummy_cache_size)
        # Same dict object as the base class's accumulated sizes (an alias, not a copy).
        self._cache_sizes = self._accumulated_target_cache_sizes

    def estimate_traffic(self, row):
        """Return ``row.get_flex_mem_traffic`` applied to the accumulated target sizes."""
        return row.get_flex_mem_traffic(self._cache_sizes)
