import pandas as pd
import requests
from bs4 import BeautifulSoup
import numpy as np
import re
import matplotlib.pyplot as plt
from scipy.stats import linregress
from scipy.special import logit


def get_data(url, table_id, col_ids, col_names, timeout=30):
    """Scrape selected columns of one ``wikitable`` on a web page into a DataFrame.

    Args:
        url: page to download.
        table_id: index of the target table among the page's
            ``<table class="wikitable">`` elements.
        col_ids: ``<td>`` cell indices to extract from each row, in output order.
            When a row is too short for these indices, it is retried with the
            last index decremented by one (presumably to cope with rows that
            lack a cell because of merged/rowspan cells -- TODO confirm).
            The caller's sequence is never modified.
        col_names: column labels for the returned DataFrame.
        timeout: seconds to wait for the HTTP response.

    Returns:
        DataFrame with one row per table row that has ``<td>`` cells, holding
        the whitespace-stripped cell text.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        IndexError: if ``table_id`` is out of range, or a row is too short even
            for the fallback indices.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    soup = BeautifulSoup(response.content, 'html.parser')
    target_table = soup.find_all('table', {'class': 'wikitable'})[table_id]

    data = []
    for row in target_table.find_all('tr'):
        cells = row.find_all('td')
        if not cells:
            continue  # rows without <td> cells (e.g. header rows) carry no data
        try:
            data.append([cells[i].text.strip() for i in col_ids])
        except IndexError:
            # Build a separate fallback list instead of editing col_ids in place:
            # the old in-place edit leaked into the caller if the retry failed too.
            fallback_ids = [*col_ids[:-1], col_ids[-1] - 1]
            data.append([cells[i].text.strip() for i in fallback_ids])

    return pd.DataFrame(data, columns=col_names)

def parse_century_string(century_string):
    """Turn a phrase such as '5th century BCE' into an inclusive (first_year, last_year) pair.

    'CE' centuries map to positive years: the Nth century spans (N-1)*100+1 .. N*100.
    'BC'/'BCE' centuries map to negative years: the Nth century spans
    -N*100 .. -(N-1)*100-1 (so the 1st century BC is -100 .. -1).

    The unrecognised input is printed before ValueError("Invalid format") is raised.
    """
    found = re.match(r'(\d+)(st|nd|rd|th) century (BCE|BC|CE)', century_string)
    if found is None:
        print(century_string)
        raise ValueError("Invalid format")

    number = int(found.group(1))
    if found.group(3) != 'CE':
        # Anything other than CE ('BCE' or 'BC') counts backwards from year -1.
        return -number * 100, -(number - 1) * 100 - 1
    return (number - 1) * 100 + 1, number * 100

def convert_year(year, start=True):
    """Convert a year label from the anthropogenic-disasters table into an int.

    Args:
        year: label such as '1939', '1930s', '5th century BC', '500 BCE',
            'Present', 'ongoing' or 'February 1917'.
        start: when the label denotes a span (a decade or a century), return its
            first year if True, otherwise its last year.

    Decades: '1930s' -> 1930 (start) / 1939 (end). Every trailing zero becomes a 9
    for the end year.
    NOTE(review): that makes '1900s' -> 1999 and '2000s' -> 2999; confirm intended.
    'Present' and 'ongoing' both map to 2023.
    """
    if year[-1] == 's':
        stem = year.rstrip('s')
        if start:
            return int(stem)
        without_zeros = stem.rstrip('0')
        nines = '9' * (len(stem) - len(without_zeros))
        return int(without_zeros + nines)

    if 'century' in year:
        first, last = parse_century_string(year)
        return first if start else last

    if 'BCE' in year:
        # BCE years become negative numbers
        return -int(year.replace('BCE', '').strip())
    if year in ('Present', 'ongoing'):
        return 2023
    if 'February' in year:
        return int(year.replace('February', '').strip())
    # Plain year
    return int(year)

population_df = pd.read_csv('population.csv')

# Keep only the worldwide series (assumes the aggregate rows are labelled 'World' in the 'Entity' column)
global_population_df = population_df[population_df['Entity'] == 'World']

# Shorten the 'Population (historical estimates)' column name for easier access
global_population_df = global_population_df.rename(columns={'Population (historical estimates)': 'Population'})

# Keep just Year and Population, in chronological order, so lookups of the
# neighbouring known years by row position see earlier years first.
global_population_df = (
    global_population_df[['Year', 'Population']]
    .sort_values('Year', kind='stable')
    .reset_index(drop=True)
)

# Geometric Interpolation Function
def geometric_interpolation(year, df):
    """Estimate the population in ``year`` from a table of known population figures.

    If ``year`` appears in ``df['Year']``, the first matching recorded value is
    returned. Otherwise the nearest earlier and nearest later known years are
    found, and the population is interpolated geometrically (constant growth
    rate between them):

        P(year) = P_lo * (P_hi / P_lo) ** ((year - Y_lo) / (Y_hi - Y_lo))

    Args:
        year: the year to estimate.
        df: table with at least 'Year' and 'Population' columns. Row order and
            any extra columns do not affect the result.

    Returns:
        The recorded or interpolated population, or ``np.nan`` when ``year`` is
        not bracketed by known years on both sides.
    """
    known_years = df['Year']
    exact = df.loc[known_years == year, 'Population']
    if not exact.empty:
        return exact.iloc[0]

    earlier = df.loc[known_years < year]
    later = df.loc[known_years > year]
    if earlier.empty or later.empty:
        return np.nan  # No data to interpolate from

    # Choose the closest bracketing rows by Year value rather than by row
    # position, so an unsorted table still yields the true neighbours; read
    # columns by name so extra columns cannot break the lookup.
    lo = earlier.iloc[earlier['Year'].to_numpy().argmax()]
    hi = later.iloc[later['Year'].to_numpy().argmin()]
    lower_year, lower_pop = lo['Year'], lo['Population']
    upper_year, upper_pop = hi['Year'], hi['Population']
    return lower_pop * ((upper_pop / lower_pop) ** ((year - lower_year) / (upper_year - lower_year)))


def isdig(x):
    """Return True if the single character ``x`` is a decimal digit or '.'.

    Uses ``str.isdecimal`` rather than ``str.isdigit``: isdigit also accepts
    characters such as superscript '²', which ``float()`` rejects. Letting them
    through would make the later numeric conversion crash.
    """
    return str.isdecimal(x) or x == '.'


# Parse free-text death-toll cells into numbers
def clean_population(value):
    """Parse a free-text death-toll cell into a single number.

    The text may contain thousands separators, footnote markers such as
    ``[12]``, parenthetical remarks, a leading "c."/"est.", the word
    ``million``, and ranges written with '-', '–', '—' or 'but'. A range is
    collapsed to the geometric mean of its positive endpoints.

    Args:
        value: cell content; non-string values are converted with ``str``.

    Returns:
        The parsed count as a float, or ``0`` when no number can be found.
    """
    text = str(value)
    # Strip footnotes and parentheticals first: a dash inside them, e.g.
    # "(1932–33)", must not be read as an estimate range, and their digits
    # must not leak into the number.
    text = re.sub(r"\[.*?\]", "", text)
    text = re.sub(r"\(.*?\)", "", text)

    if 'million' in text:
        return clean_population(text.replace('million', '')) * 1000000

    text = text.replace('–', '-').replace('—', '-').replace('but', '-')
    if '-' in text:
        estimates = [clean_population(part) for part in text.split('-')]
        # Skip endpoints that parsed to 0 (e.g. "?" or an empty side): log(0)
        # would otherwise drag the whole estimate down to 0.
        positive = np.array([e for e in estimates if e > 0])
        if positive.size == 0:
            return 0
        return np.exp(np.mean(np.log(positive)))

    # Keep only digits and '.', then drop dots left over from abbreviations
    # such as "c. 70,000" that would otherwise turn it into 0.7.
    digits = re.sub(r"[^\d.]", "", text).strip('.')
    return float(digits) if digits else 0


# Abbreviate a count as whole millions (M) or, below one million, whole thousands (K)
def format_millions_or_thousands(value):
    """Return ``value`` as a short label such as '12M' or '350K' (rounded with ``round``)."""
    if value >= 1e6:
        scaled, unit = value / 1e6, "M"
    else:
        scaled, unit = value / 1e3, "K"
    return f"{round(scaled)}{unit}"

# Abbreviate a population as billions with two decimals (B) or, below one billion, whole millions (M)
def format_billions_or_millions(value):
    """Return ``value`` as a short label such as '1.23B' or '500M'."""
    in_billions = value >= 1e9
    label = f"{round(value / 1e9, 2)}B" if in_billions else f"{round(value / 1e6)}M"
    return label

# Render a percentage rounded to one decimal place, e.g. 12.34 -> '12.3%'
def round_percent(value):
    """Return ``value`` rounded to the nearest 0.1 with a '%' suffix."""
    return "{}%".format(round(value, 1))

def process_and_clean(df):
    """Add population-relative death tolls to ``df`` and sort by them.

    Modifies ``df`` in place:
    - 'Start Population' is the world population interpolated at 'Start',
      taken from the module-level ``global_population_df``.
    - 'Death Toll (Total)' is parsed to a number with ``clean_population``.
    - 'Death Toll (Percent)' is the death toll as a percentage of 'Start Population'.

    Returns:
        A copy of ``df`` sorted by 'Death Toll (Percent)', largest first.
    """
    def population_at(year):
        return geometric_interpolation(year, global_population_df)

    df['Start Population'] = df['Start'].apply(population_at)
    df['Death Toll (Total)'] = df['Death Toll (Total)'].apply(clean_population)
    share_of_population = df['Death Toll (Total)'] / df['Start Population']
    df['Death Toll (Percent)'] = share_of_population * 100
    return df.sort_values(by='Death Toll (Percent)', ascending=False)

def clean2(df_sorted, min_percent=0.1):
    """Drop minor events and format the remaining rows for display.

    Args:
        df_sorted: frame with numeric 'Death Toll (Percent)', 'Death Toll (Total)'
            and 'Start Population' columns (as produced by process_and_clean).
        min_percent: rows whose 'Death Toll (Percent)' is below this value are
            dropped (default 0.1, i.e. 0.1% of the world population).

    Returns:
        A new DataFrame; ``df_sorted`` is left untouched. 'Death Toll (Total)' is
        rendered as e.g. '12M'/'350K' and 'Start Population' as e.g. '1.5B'/'500M'.
    """
    keep = df_sorted['Death Toll (Percent)'] >= min_percent
    # .copy(): the columns are rewritten below, and assigning into a
    # boolean-filtered slice triggers pandas' SettingWithCopyWarning.
    df_filtered = df_sorted[keep].copy()

    df_filtered['Death Toll (Total)'] = df_filtered['Death Toll (Total)'].apply(format_millions_or_thousands)
    df_filtered['Start Population'] = df_filtered['Start Population'].apply(format_billions_or_millions)
    return df_filtered



# Helpers for converting date strings into (start year, end year) pairs
def adjust_year(start, end):
    """
    Expand an abbreviated end year using the leading digits of the start year.

    Handles cases like 1319-20 -> 1320 and 1100-03 -> 1103. It also handles
    ranges that roll over into the next decade or century, such as
    1999-01 -> 2001 and 1899-5 -> 1905.

    Args:
        start: start year (int or digit string).
        end: end year, possibly abbreviated (int or digit string).

    Returns:
        The expanded end year as an int when ``end`` has fewer digits than
        ``start``; otherwise ``end`` unchanged.
    """
    start_str = str(start)
    end_str = str(end)

    if len(end_str) >= len(start_str):
        # Already a full year: nothing to adjust
        return end

    # Borrow the leading digits the abbreviation left out
    prefix = start_str[:len(start_str) - len(end_str)]
    adjusted_end = int(prefix + end_str)

    # A result before the start means the range crossed a 10**k boundary
    # (e.g. 1999-01): move forward into the next block of years.
    if start_str.isdigit() and adjusted_end < int(start_str):
        adjusted_end += 10 ** len(end_str)

    return adjusted_end

def convert_start_date(date_str):
    """Return the first year named by ``date_str``.

    A decade such as '950s' yields its first year (950). Every 's' is removed
    before parsing; when there is none the removal is a no-op.
    """
    return int(date_str.replace('s', ''))

def convert_end_date(date_str, start_date):
    """Return the last year named by ``date_str``, relative to ``start_date``.

    A decade such as '1930s' ends nine years after it starts (1939). The result
    is then passed through ``adjust_year`` so that abbreviated end years like
    '20' in '1319-20' take their leading digits from ``start_date``.
    """
    is_decade = 's' in date_str
    year = int(date_str.replace('s', ''))
    if is_decade:
        year += 9
    return int(adjust_year(str(start_date), str(year)))


months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']

def convert_date(date_str):
    """Parse a free-form date cell into an inclusive ``[start_year, end_year]`` list.

    Supported inputs:
    - single years ('1918') and decades ('1930s');
    - ranges ('1918-20', '1959–1961');
    - dates with a month name ('March 1918', 'May 4, 1918');
    - comma-separated lists, which span the minimum start to the maximum end;
    - BC/BCE years and ranges ('430–426 BC'), which become negative years;
    - 'present' in any letter case, which maps to 2023.

    Footnote markers ``[..]`` and parenthetical remarks ``(..)`` are ignored.

    Raises:
        ValueError: if a fragment cannot be parsed as a year.
    """
    for month in months:
        if month in date_str:
            if ',' in date_str:
                # "May 4, 1918": the year follows the first comma
                return convert_date(date_str.split(',')[1])
            return convert_date(date_str.replace(month, ''))

    if ',' in date_str:
        dates = [convert_date(part) for part in date_str.split(',')]
        return [min(d[0] for d in dates), max(d[1] for d in dates)]

    date_str = date_str.replace('–', '-').replace('—', '-').replace('−', '-')
    date_str = date_str.replace('and', '')
    date_str = re.sub(r"\[.*?\]", "", date_str)
    date_str = re.sub(r"\(.*?\)", "", date_str)
    date_str = date_str.replace('BCE', 'BC')
    date_str = date_str.replace('\xa0', ' ')

    # Handle BC (with or without a space before the marker)
    if 'BC' in date_str:
        years = date_str.replace('BC', '').split('-')
        if len(years) == 1:
            years = years + years
        # Negating "430-426 BC" gives [-430, -426], which is already
        # chronological; sorting (instead of reversing) keeps start <= end
        # whichever way round the range was written.
        return sorted(-int(y) for y in years)

    # Handle 'present' in any capitalisation
    if 'present' in date_str.lower():
        date_str = re.sub('present', '2023', date_str, flags=re.IGNORECASE)

    # Handle date ranges
    if '-' in date_str:
        parts = date_str.split('-')
        start_date = convert_start_date(parts[0])
        end_date = convert_end_date(parts[1], start_date)
        return [start_date, end_date]

    # Single year
    start_date = convert_start_date(date_str)
    end_date = convert_end_date(date_str, start_date)
    return [start_date, end_date]


def _load_anthropogenic_table(table_id):
    """Scrape one table of the anthropogenic-disasters list, convert its years, and clean it."""
    table = get_data("https://en.m.wikipedia.org/wiki/List_of_anthropogenic_disasters_by_death_toll", table_id,
                     [0, 3, 5, 6], ['Event', 'Death Toll (Total)', 'Start', 'End'])
    table['Start'] = table['Start'].apply(lambda label: convert_year(label, start=True))
    table['End'] = table['End'].apply(lambda label: convert_year(label, start=False))
    table['Start'] = table['Start'].astype(int)
    return clean2(process_and_clean(table))


# Wars
df_wars = _load_anthropogenic_table(0)

# Forced Labor
df_labor = _load_anthropogenic_table(2)

def _load_dated_table(url, table_id, col_ids):
    """Scrape a table whose dates sit in one free-form 'Date' column; split them into Start/End and clean."""
    table = get_data(url, table_id, col_ids, ['Event', 'Death Toll (Total)', 'Date'])
    table[['Start', 'End']] = table['Date'].apply(convert_date).tolist()
    return clean2(process_and_clean(table))


# Famines
df_famines = _load_dated_table("https://en.m.wikipedia.org/wiki/List_of_famines", 0, [1, 3, 0])

# Epidemics and pandemics
df_epidemics = _load_dated_table("https://en.m.wikipedia.org/wiki/List_of_epidemics_and_pandemics", 0, [1, 3, 6])

# Natural disasters
df_disasters = _load_dated_table('https://en.m.wikipedia.org/wiki/List_of_natural_disasters_by_death_toll', 0, [1, 0, 3])

# Tag every table with its category. The raw 'Date' text was only needed to
# derive Start/End, so drop it. DataFrame.drop returns a new frame; the result
# must be reassigned or the drop has no effect.
df_wars['Type'] = 'War'
df_labor['Type'] = 'Forced Labor'
df_famines = df_famines.drop(columns=['Date'])
df_famines['Type'] = 'Famine'
df_epidemics = df_epidemics.drop(columns=['Date'])
df_epidemics['Type'] = 'Epidemic'
df_disasters = df_disasters.drop(columns=['Date'])
df_disasters['Type'] = 'Natural Disaster'


all_tables = [df_wars, df_labor, df_famines, df_epidemics, df_disasters]
df_all = pd.concat(all_tables)
# NOTE(review): exp(logit(p)) simplifies to the odds p / (1 - p), so despite its
# name this column holds odds rather than log-odds; confirm which was intended.
death_fraction = df_all['Death Toll (Percent)'] / 100.0
df_all['Death Toll (Logit)'] = np.exp(logit(death_fraction))

# Rank events by the share of the world population they killed
metric = 'Death Toll (Percent)'
df_all = df_all.sort_values(by=metric, ascending=False)

# One scatter series per event type: start year vs. metric on a log y-axis
plt.figure(figsize=(10, 6))
for type_category in df_all['Type'].unique():
    subset = df_all.loc[df_all['Type'] == type_category]
    plt.scatter(subset['Start'], subset[metric], label=type_category)

plt.title('Scatter Plot by Type')
plt.xlabel('Start Year')
plt.ylabel(metric)
plt.yscale('log')
plt.legend(title='Type')
plt.show()


df_all = df_all.reset_index(drop=True)
rank = df_all.index + 1  # 1-based position in the descending ordering
log_x = np.log(rank)
log_y = np.log(df_all[metric])
# Least-squares line in log-log space: log(metric) = intercept + slope * log(rank),
# i.e. metric ~ exp(intercept) * rank ** slope.
slope, intercept, r_value, p_value, std_err = linregress(log_x, log_y)
fitted_y = np.exp(intercept) * (rank ** slope)

plt.figure(figsize=(10, 6))
plt.loglog(rank, df_all[metric], marker='o', linestyle='')
plt.loglog(rank, fitted_y, label=f'Power Law Fit: exponent = {slope:.2f}', color='red')
plt.title('Log-Log Plot of Death Toll vs Sorted Index')
plt.xlabel('Sorted Index (Descending by Death Toll)')
plt.ylabel(metric)
plt.legend()
plt.show()

# Events lasting at most ten years (End - Start <= 10)
df_sudden = df_all[df_all['End'] - df_all['Start'] <= 10]
# ...that killed at least 1% of the world population. Take an explicit copy,
# because the column is reformatted below; writing into a boolean-filtered
# slice triggers pandas' SettingWithCopyWarning.
df_strict = df_sudden[df_sudden['Death Toll (Percent)'] >= 1].copy()
df_strict['Death Toll (Percent)'] = df_strict['Death Toll (Percent)'].apply(round_percent)
df_strict.to_csv('events_strict.csv', index=False)

df_all['Death Toll (Percent)'] = df_all['Death Toll (Percent)'].apply(round_percent)
df_all.to_csv('events_all.csv', index=False)

import statsmodels.api as sm
import statsmodels.sandbox as sbx
# Import the runs test from its defining module. Reaching it as
# `sbx.stats.runs` only works if some other import already loaded that submodule.
from statsmodels.sandbox.stats.runs import runstest_1samp

# Count catastrophes by start year in each decade from 1500 through the 2010s
step = 10
first_year, stop_year = 1500, 2019
years = np.arange(first_year, stop_year, step)
counts = np.array([
    ((df_all['Start'] >= decade) & (df_all['Start'] < decade + step)).sum()
    for decade in years
])

plt.bar(years, counts, width=step)
plt.title('Catastrophes Per Decade')
plt.xlabel('Starting Year')
plt.ylabel('# of Catastrophes')
plt.show()

# The randomness tests use only the decades starting before 1900 (1500-1899)
before_1900 = years < 1900

# Ljung-Box p-value for 1500-1900
print(sm.stats.diagnostic.acorr_ljungbox(counts[before_1900], lags=1))
#     lb_stat  lb_pvalue
# 1  0.840637   0.359215

# Wald-Wolfowitz p-value for 1500-1900
print(runstest_1samp(counts[before_1900], cutoff=0.5, correction=True))
# (-1.1176210768562695, 0.2637288630915612)
