import base64
import bcrypt
import datetime
import glob
import io
import json
import os
import pickle
import re
import smtplib
import sqlite3
from collections import OrderedDict
from pathlib import Path

import numpy as np
import pandas as pd
import plotly
import plotly.graph_objects as go
import psycopg2
import streamlit as st
import yaml
from plotly.subplots import make_subplots
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from streamlit.components.v1 import html
from yaml import SafeLoader

from scenario import (
    Channel,
    Scenario,
    class_convert_to_dict,
    class_from_dict,
    class_to_dict,
    numerize,
)

# schema = db_cred["schema"]
color_palette = [
"#F3F3F0",
"#5E7D7E",
"#2FA1FF",
"#00EDED",
"#00EAE4",
"#304550",
"#EDEBEB",
"#7FBEFD",
"#003059",
"#A2F3F3",
"#E1D6E2",
"#B6B6B6",
]
CURRENCY_INDICATOR = "$"
db_cred = None
# database_file = r"DB/User.db"
# conn = sqlite3.connect(database_file, check_same_thread=False) # connection with sql db
# c = conn.cursor()
# def query_excecuter_postgres(
# query,
# db_cred,
# params=None,
# insert=True,
# insert_retrieve=False,
# ):
# """
# Executes a SQL query on a PostgreSQL database, handling both insert and select operations.
# Parameters:
# query (str): The SQL query to be executed.
# params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
# insert (bool, default=True): Flag to determine if the query is an insert operation (default) or a select operation.
# insert_retrieve (bool, default=False): Flag to determine if the query should insert and then return the inserted ID.
# """
# # Database connection parameters
# dbname = db_cred["dbname"]
# user = db_cred["user"]
# password = db_cred["password"]
# host = db_cred["host"]
# port = db_cred["port"]
# try:
# # Establish connection to the PostgreSQL database
# conn = psycopg2.connect(
# dbname=dbname, user=user, password=password, host=host, port=port
# )
# except psycopg2.Error as e:
# st.warning(f"Unable to connect to the database: {e}")
# st.stop()
# # Create a cursor object to interact with the database
# c = conn.cursor()
# try:
# # Execute the query with or without parameters
# if params:
# c.execute(query, params)
# else:
# c.execute(query)
# if not insert:
# # If not an insert operation, fetch and return the results
# results = c.fetchall()
# return results
# elif insert_retrieve:
# # If insert and retrieve operation, fetch and return the results
# conn.commit()
# return c.fetchall()
# else:
# conn.commit()
# except Exception as e:
# st.write(f"Error executing query: {e}")
# finally:
# conn.close()
db_path = os.path.join("imp_db.db")


def query_excecuter_postgres(
    query, db_path=None, params=None, insert=True, insert_retrieve=False, db_cred=None
):
    """
    Executes a SQL query on a SQLite database, handling both insert and select operations.

    Parameters:
        query (str): The SQL query to be executed.
        db_path (str, optional): Ignored; the database at ./db/imp_db.db is always used.
        params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
        insert (bool, default=True): True for insert/update/delete queries, False for select queries.
        insert_retrieve (bool, default=False): If True, commit the insert and return the last inserted row ID.
        db_cred (dict, optional): Unused; kept so call sites written for the PostgreSQL version still work.
    """
    try:
        # Construct a cross-platform path to the database
        db_dir = os.path.join("db")
        os.makedirs(db_dir, exist_ok=True)  # Make sure the directory exists
        db_path = os.path.join(db_dir, "imp_db.db")

        # Establish connection to the SQLite database
        conn = sqlite3.connect(db_path)
    except sqlite3.Error as e:
        st.warning(f"Unable to connect to the SQLite database: {e}")
        st.stop()

    # Create a cursor object to interact with the database
    c = conn.cursor()

    try:
        if params:
            # Expand an `IN (?)` clause to one placeholder per parameter
            query = query.replace("IN (?)", f"IN ({','.join(['?' for _ in params])})")
            c.execute(query, params)
        else:
            c.execute(query)

        if not insert:
            # Select operation: fetch and return the results
            results = c.fetchall()
            return results
        elif insert_retrieve:
            # Insert-and-retrieve operation: commit and return the last inserted row ID
            conn.commit()
            return c.lastrowid
        else:
            # Standard insert/update/delete: commit the transaction
            conn.commit()
    except Exception as e:
        st.write(f"Error executing query: {e}")
    finally:
        conn.close()
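
# Illustrative usage of query_excecuter_postgres (a sketch, not executed at import time;
# the mmo_users / mmo_projects tables and the "emp001" id are assumptions taken from the
# queries used elsewhere in this module, and db_cred/db_path are ignored by the SQLite
# implementation):
#
#   rows = query_excecuter_postgres(
#       "SELECT emp_nam FROM mmo_users WHERE emp_id = ?;",
#       params=("emp001",),
#       insert=False,          # select: returns a list of tuples
#   )
#   new_id = query_excecuter_postgres(
#       "INSERT INTO mmo_projects (prj_nam, prj_ownr_id) VALUES (?, ?);",
#       params=("demo_project", "emp001"),
#       insert_retrieve=True,  # insert: commits and returns the last row id
#   )
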
def update_summary_df():
    """
    Updates 'project_summary_df' in the session state with the latest project
    summary information based on the most recent updates.

    This function executes a SQL query to retrieve project metadata from the
    database and stores the result in the session state.

    Uses:
    - query_excecuter_postgres(query, db_cred, params=params, insert=False): executes
      the provided SQL query against the project database.

    Modifies:
    - st.session_state['project_summary_df']: Updates the dataframe with columns:
      'Project Number', 'Project Name', 'Last Modified Page', 'Last Modified Time'.
    """
query = f"""
WITH LatestUpdates AS (
SELECT
prj_id,
page_nam,
updt_dt_tm,
ROW_NUMBER() OVER (PARTITION BY prj_id ORDER BY updt_dt_tm DESC) AS rn
FROM
mmo_project_meta_data
)
SELECT
p.prj_id,
p.prj_nam AS prj_nam,
lu.page_nam,
lu.updt_dt_tm
FROM
LatestUpdates lu
RIGHT JOIN
mmo_projects p ON lu.prj_id = p.prj_id
WHERE
p.prj_ownr_id = ? AND lu.rn = 1
"""
params = (st.session_state["emp_id"],) # Parameters for the SQL query
# Execute the query and retrieve project summary data
project_summary = query_excecuter_postgres(
query, db_cred, params=params, insert=False
)
# Update the session state with the project summary dataframe
st.session_state["project_summary_df"] = pd.DataFrame(
project_summary,
columns=[
"Project Number",
"Project Name",
"Last Modified Page",
"Last Modified Time",
],
)
st.session_state["project_summary_df"] = st.session_state[
"project_summary_df"
].sort_values(by=["Last Modified Time"], ascending=False)
return st.session_state["project_summary_df"]
from constants import default_dct
def ensure_project_dct_structure(session_state, default_dct):
for key, value in default_dct.items():
if key not in session_state:
session_state[key] = value
elif isinstance(value, dict):
ensure_project_dct_structure(session_state[key], value)
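
# Illustrative behaviour of ensure_project_dct_structure (a sketch; the keys shown here
# are only a fragment of the real default_dct defined in constants):
#
#   defaults = {"response_curves": {"original_metadata_file": None}}
#   project_dct = {"response_curves": {}}
#   ensure_project_dct_structure(project_dct, defaults)
#   # project_dct -> {"response_curves": {"original_metadata_file": None}}
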
def project_selection():
emp_id = st.text_input("employee id", key="emp1111").lower()
password = st.text_input("Password", max_chars=15, type="password")
if st.button("Login"):
if "unique_ids" not in st.session_state:
unique_users_query = f"""
SELECT DISTINCT emp_id, emp_nam, emp_typ from mmo_users;
"""
unique_users_result = query_excecuter_postgres(
unique_users_query, db_cred, insert=False
            )  # retrieves all the users who have access to the MMO tool
st.session_state["unique_ids"] = {
emp_id: (emp_nam, emp_type)
for emp_id, emp_nam, emp_type in unique_users_result
}
if emp_id not in st.session_state["unique_ids"].keys() or len(password) == 0:
            st.warning("Invalid ID or password!")
st.stop()
if not is_pswrd_flag_set(emp_id):
st.warning("Reset password in home page to continue")
st.stop()
elif not verify_password(emp_id, password):
st.warning("Invalid user name or password")
st.stop()
else:
st.session_state["emp_id"] = emp_id
st.session_state["username"] = st.session_state["unique_ids"][
st.session_state["emp_id"]
][0]
with st.spinner("Loading Saved Projects"):
st.session_state["project_summary_df"] = update_summary_df()
# st.write(st.session_state["project_name"][0])
if len(st.session_state["project_summary_df"]) == 0:
                st.warning("No projects found. Please create a project on the Home page.")
st.stop()
else:
try:
st.session_state["project_name"] = (
st.session_state["project_summary_df"]
.loc[
st.session_state["project_summary_df"]["Project Number"]
== st.session_state["project_summary_df"].iloc[0, 0],
"Project Name",
]
.values[0]
) # fetching project name from project number stored in summary df
                    project_dct_query = """
                    SELECT pkl_obj FROM mmo_project_meta_data WHERE prj_id = ? AND file_nam = ?;
                    """
# Execute the query and retrieve the result
project_number = int(st.session_state["project_summary_df"].iloc[0, 0])
st.session_state["project_number"] = project_number
project_dct_retrieved = query_excecuter_postgres(
                        project_dct_query,
db_cred,
params=(project_number, "project_dct"),
insert=False,
)
# retrieves project dict (meta data) stored in db
st.session_state["project_dct"] = pickle.loads(
project_dct_retrieved[0][0]
                    )  # converting bytes data back to the original object using pickle
ensure_project_dct_structure(
st.session_state["project_dct"], default_dct
)
                    st.success("Project Loaded")
st.rerun()
except Exception as e:
                    st.write(
                        "Failed to load project metadata from the database. Please create a new project!"
                    )
st.stop()
def update_db(prj_id, page_nam, file_nam, pkl_obj, resp_mtrc="", schema=""):
# Check if an entry already exists
check_query = f"""
SELECT 1 FROM mmo_project_meta_data
WHERE prj_id = ? AND file_nam =?;
"""
check_params = (prj_id, file_nam)
result = query_excecuter_postgres(
check_query, db_cred, params=check_params, insert=False
)
# If entry exists, perform an update
if result is not None and result:
update_query = f"""
UPDATE mmo_project_meta_data
        SET file_nam = ?, pkl_obj = ?, page_nam = ?, updt_dt_tm = datetime('now')
WHERE prj_id = ? AND file_nam = ?;
"""
update_params = (file_nam, pkl_obj, page_nam, prj_id, file_nam)
query_excecuter_postgres(
update_query, db_cred, params=update_params, insert=True
)
# If entry does not exist, perform an insert
else:
insert_query = f"""
INSERT INTO mmo_project_meta_data
(prj_id, page_nam, file_nam, pkl_obj,crte_by_uid, crte_dt_tm, updt_dt_tm)
VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'));
"""
insert_params = (
prj_id,
page_nam,
file_nam,
pkl_obj,
st.session_state["emp_id"],
)
query_excecuter_postgres(
insert_query, db_cred, params=insert_params, insert=True
)
# st.success(f"Inserted project meta data for project {prj_id}, page {page_nam}")
def retrieve_pkl_object(prj_id, page_nam, file_nam, schema=""):
query = f"""
SELECT pkl_obj FROM mmo_project_meta_data
WHERE prj_id = ? AND page_nam = ? AND file_nam = ?;
"""
params = (prj_id, page_nam, file_nam)
result = query_excecuter_postgres(
query, db_cred=db_cred, params=params, insert=False
)
if result and result[0] and result[0][0]:
pkl_obj = result[0][0]
# Deserialize the pickle object
return pickle.loads(pkl_obj)
else:
return None
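
# Illustrative save/load round trip (a sketch; the project id, page and file names are
# made up, pkl_obj must be a pickled bytes blob, and update_db additionally expects
# st.session_state["emp_id"] to be set when it has to insert a new row):
#
#   update_db(1, "scenario_planner", "final_scenario", pickle.dumps({"a": 1}))
#   retrieve_pkl_object(1, "scenario_planner", "final_scenario")   # -> {"a": 1}
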
def validate_text(input_text):
# Check the length of the text
if len(input_text) < 2:
return False, "Input should be at least 2 characters long."
if len(input_text) > 30:
return False, "Input should not exceed 30 characters."
# Check if the text contains only allowed characters
if not re.match(r"^[A-Za-z0-9_]+$", input_text):
return (
False,
"Input contains invalid characters. Only letters, numbers and underscores are allowed.",
)
return True, "Input is valid."
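
# Illustrative checks for validate_text (a sketch; the sample strings are made up):
#
#   validate_text("scenario_v2")   # -> (True, "Input is valid.")
#   validate_text("a")             # -> (False, "Input should be at least 2 characters long.")
#   validate_text("bad name!")     # -> (False, "Input contains invalid characters. ...")
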
def delete_entries(prj_id, page_names, db_cred=None, schema=None):
"""
Deletes all entries from the project_meta_data table based on prj_id and a list of page names.
Parameters:
prj_id (int): The project ID.
page_names (list): A list of page names.
db_cred (dict): Database credentials with keys 'dbname', 'user', 'password', 'host', 'port'.
schema (str): The schema name.
"""
# Create placeholders for each page name in the list
placeholders = ", ".join(["?"] * len(page_names))
query = f"""
DELETE FROM mmo_project_meta_data
WHERE prj_id = ? AND page_nam IN ({placeholders});
"""
# Combine prj_id and page_names into one list of parameters
params = (prj_id, *page_names)
query_excecuter_postgres(query, db_cred, params=params, insert=True)
# st.success(f"Deleted entries for project {prj_id}, page {page_name}")
def store_hashed_password(
    user_id,
    plain_text_password,
):
    """
    Hashes a plain text password using bcrypt, converts it to a UTF-8 string, and stores it as text.

    Parameters:
        user_id: The employee ID whose password is being set.
        plain_text_password (str): The plain text password to be hashed.
    """
# Hash the plain text password
hashed_password = bcrypt.hashpw(
plain_text_password.encode("utf-8"), bcrypt.gensalt()
)
# Convert the byte string to a regular string for storage
hashed_password_str = hashed_password.decode("utf-8")
# SQL query to update the pswrd_key for the specified user_id
query = f"""
UPDATE mmo_users
SET pswrd_key = ?
WHERE emp_id = ?;
"""
# Execute the query using the existing query_excecuter_postgres function
query_excecuter_postgres(
query=query, db_cred=db_cred, params=(hashed_password_str, user_id), insert=True
)
def verify_password(user_id, plain_text_password):
"""
Verifies the plain text password against the stored hashed password for the specified user_id.
    Parameters:
        user_id: The ID of the user whose password is being verified.
        plain_text_password (str): The plain text password to verify.
"""
# SQL query to retrieve the hashed password for the user_id
query = f"""
SELECT pswrd_key FROM mmo_users WHERE emp_id = ?;
"""
# Execute the query using the existing query_excecuter_postgres function
result = query_excecuter_postgres(
query=query, db_cred=db_cred, params=(user_id,), insert=False
)
if result:
stored_hashed_password_str = result[0][0]
# Convert the stored string back to bytes
stored_hashed_password = stored_hashed_password_str.encode("utf-8")
if bcrypt.checkpw(plain_text_password.encode("utf-8"), stored_hashed_password):
return True
else:
return False
else:
return False
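
# Illustrative password flow (a sketch; "emp001" and the password are made-up values,
# and both helpers expect the mmo_users table to exist in the project database):
#
#   store_hashed_password("emp001", "s3cret_pw")   # bcrypt-hash and persist
#   verify_password("emp001", "s3cret_pw")         # -> True
#   verify_password("emp001", "wrong_pw")          # -> False
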
def update_password_in_db(user_id, plain_text_password):
    """
    Hashes the plain text password and updates the `pswrd_key`
    column for the given `emp_id` in the `mmo_users` table.

    Parameters:
        user_id: The ID of the user whose password needs to be updated.
        plain_text_password (str): The plain text password to be hashed and stored.
    """
# Hash the plain text password using bcrypt
hashed_password = bcrypt.hashpw(
plain_text_password.encode("utf-8"), bcrypt.gensalt()
)
# Convert the hashed password from bytes to a string for storage
hashed_password_str = hashed_password.decode("utf-8")
# SQL query to update the password in the database
query = f"""
UPDATE mmo_users
SET pswrd_key = ?
WHERE emp_id = ?
"""
# Parameters for the query
params = (hashed_password_str, user_id)
# Execute the query using the query_excecuter_postgres function
query_excecuter_postgres(query, db_cred, params=params, insert=True)
def is_pswrd_flag_set(user_id):
query = f"""
SELECT pswrd_flag
FROM mmo_users
WHERE emp_id = ?;
"""
# Execute the query
result = query_excecuter_postgres(query, db_cred, params=(user_id,), insert=False)
# Return True if the flag is 1, otherwise return False
if result and result[0][0] == 1:
return True
else:
return False
def set_pswrd_flag(user_id):
query = f"""
UPDATE mmo_users
SET pswrd_flag = 1
WHERE emp_id = ?;
"""
# Execute the update query
query_excecuter_postgres(query, db_cred, params=(user_id,), insert=True)
def retrieve_pkl_object_without_warning(prj_id, page_nam, file_nam, schema):
query = f"""
SELECT pkl_obj FROM mmo_project_meta_data
WHERE prj_id = ? AND page_nam = ? AND file_nam = ?;
"""
params = (prj_id, page_nam, file_nam)
result = query_excecuter_postgres(
query, db_cred=db_cred, params=params, insert=False
)
if result and result[0] and result[0][0]:
pkl_obj = result[0][0]
# Deserialize the pickle object
return pickle.loads(pkl_obj)
else:
# st.warning(
# "Pickle object not found for the given project ID, page name, and file name."
# )
return None
# def load_authenticator():
# with open("config.yaml") as file:
# config = yaml.load(file, Loader=SafeLoader)
# st.session_state["config"] = config
# authenticator = stauth.Authenticate(
# credentials=config["credentials"],
# cookie_name=config["cookie"]["name"],
# key=config["cookie"]["key"],
# cookie_expiry_days=config["cookie"]["expiry_days"],
# preauthorized=config["preauthorized"],
# )
# st.session_state["authenticator"] = authenticator
# return authenticator
# Authentication
# def authenticator():
# for k, v in st.session_state.items():
# if k not in ["logout", "login", "config"] and not k.startswith(
# "FormSubmitter"
# ):
# st.session_state[k] = v
# with open("config.yaml") as file:
# config = yaml.load(file, Loader=SafeLoader)
# st.session_state["config"] = config
# authenticator = stauth.Authenticate(
# config["credentials"],
# config["cookie"]["name"],
# config["cookie"]["key"],
# config["cookie"]["expiry_days"],
# config["preauthorized"],
# )
# st.session_state["authenticator"] = authenticator
# name, authentication_status, username = authenticator.login(
# "Login", "main"
# )
# auth_status = st.session_state.get("authentication_status")
# if auth_status == True:
# authenticator.logout("Logout", "main")
# is_state_initiaized = st.session_state.get("initialized", False)
# if not is_state_initiaized:
# if "session_name" not in st.session_state:
# st.session_state["session_name"] = None
# return name
# def authentication():
# with open("config.yaml") as file:
# config = yaml.load(file, Loader=SafeLoader)
# authenticator = stauth.Authenticate(
# config["credentials"],
# config["cookie"]["name"],
# config["cookie"]["key"],
# config["cookie"]["expiry_days"],
# config["preauthorized"],
# )
# name, authentication_status, username = authenticator.login(
# "Login", "main"
# )
# return authenticator, name, authentication_status, username
def nav_page(page_name, timeout_secs=3):
    # Click the matching sidebar link from the parent frame to switch Streamlit pages.
    nav_script = """
        <script type="text/javascript">
            function nav(page, start, timeout) {
                var links = window.parent.document.getElementsByTagName("a");
                for (var i = 0; i < links.length; i++) {
                    if (links[i].href.toLowerCase().endsWith("/" + page.toLowerCase())) { links[i].click(); return; }
                }
                if (new Date() - start < timeout * 1000) { setTimeout(nav, 100, page, start, timeout); }
            }
            window.addEventListener("load", function() { nav("%s", new Date(), %d); });
        </script>
    """ % (page_name, timeout_secs)
    html(nav_script)
# def load_local_css(file_name):
#     with open(file_name) as f:
#         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


# def set_header():
#     return st.markdown(
#         f"""
#         """,
#         unsafe_allow_html=True,
#     )


# def set_header():
#     logo_path = "./path/to/your/local/LIME_logo.png"  # Replace with the actual file path
#     text = "LiME"
#     return st.markdown(f"""
#         <img src="{logo_path}">
#         {text}
#     """, unsafe_allow_html=True)
def s_curve(x, K, b, a, x0):
return K / (1 + b * np.exp(-a * (x - x0)))
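
# Parameter roles in s_curve (for reference): K is the saturation ceiling, a the
# steepness, x0 the inflection point, and b scales the early part of the curve.
# With b = 1 the curve passes through K / 2 at x = x0, e.g.:
#
#   s_curve(np.array([0.0, 5.0, 100.0]), K=10, b=1, a=1, x0=5.0)
#   # -> approximately array([0.067, 5.0, 10.0])
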
def panel_level(input_df, date_column="Date"):
# Ensure 'Date' is set as the index
if date_column not in input_df.index.names:
input_df = input_df.set_index(date_column)
# Select numeric columns only (excluding 'Date' since it's now the index)
numeric_columns_df = input_df.select_dtypes(include="number")
# Group by 'Date' (which is the index) and sum the numeric columns
aggregated_df = numeric_columns_df.groupby(input_df.index).sum()
    # Reset the index to bring back the 'Date' column
aggregated_df = aggregated_df.reset_index()
return aggregated_df
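
# Illustrative behaviour of panel_level (a sketch with made-up numbers): rows sharing a
# Date are summed across panels, and non-numeric columns such as "Panel" are dropped.
# "paid_search_spend" is a hypothetical column name.
#
#   df = pd.DataFrame({
#       "Date": ["2023-01-02", "2023-01-02"],
#       "Panel": ["north", "south"],
#       "paid_search_spend": [100.0, 50.0],
#   })
#   panel_level(df, date_column="Date")
#   # -> one row with Date 2023-01-02 and paid_search_spend 150.0
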
def fetch_actual_data(
panel=None,
target_file="Overview_data_test.xlsx",
updated_rcs=None,
metrics=None,
):
excel = pd.read_excel(Path(target_file), sheet_name=None)
# Extract dataframes for raw data, spend input, and contribution MMM
raw_df = excel["RAW DATA MMM"]
spend_df = excel["SPEND INPUT"]
contri_df = excel["CONTRIBUTION MMM"]
# Check if the panel is not None
if panel is not None and panel != "Aggregated":
raw_df = raw_df[raw_df["Panel"] == panel].drop(columns=["Panel"])
spend_df = spend_df[spend_df["Panel"] == panel].drop(columns=["Panel"])
contri_df = contri_df[contri_df["Panel"] == panel].drop(columns=["Panel"])
elif panel == "Aggregated":
raw_df = panel_level(raw_df, date_column="Date")
spend_df = panel_level(spend_df, date_column="Week")
contri_df = panel_level(contri_df, date_column="Date")
# Revenue_df = excel['Revenue']
    ## remove seasonalities, indices etc ...
    unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
    ## remove seasonalities, indices etc ...
exclude_columns = [
"Date",
"Region",
"Controls_Grammarly_Index_SeasonalAVG",
"Controls_Quillbot_Index",
"Daily_Positive_Outliers",
"External_RemoteClass_Index",
"Intervals ON 20190520-20190805 | 20200518-20200803 | 20210517-20210802",
"Intervals ON 20190826-20191209 | 20200824-20201207 | 20210823-20211206",
"Intervals ON 20201005-20201019",
"Promotion_PercentOff",
"Promotion_TimeBased",
"Seasonality_Indicator_Chirstmas",
"Seasonality_Indicator_NewYears_Days",
"Seasonality_Indicator_Thanksgiving",
"Trend 20200302 / 20200803",
] + unnamed_cols
raw_df["Date"] = pd.to_datetime(raw_df["Date"])
contri_df["Date"] = pd.to_datetime(contri_df["Date"])
input_df = raw_df.sort_values(by="Date")
output_df = contri_df.sort_values(by="Date")
spend_df["Week"] = pd.to_datetime(
spend_df["Week"], format="%Y-%m-%d", errors="coerce"
)
spend_df.sort_values(by="Week", inplace=True)
# spend_df['Week'] = pd.to_datetime(spend_df['Week'], errors='coerce')
# spend_df = spend_df.sort_values(by='Week')
channel_list = [col for col in input_df.columns if col not in exclude_columns]
channel_list = list(set(channel_list) - set(["fb_level_achieved_tier_1", "ga_app"]))
infeasible_channels = [
c
for c in contri_df.select_dtypes(include=["float", "int"]).columns
if contri_df[c].sum() <= 0
]
# st.write(channel_list)
channel_list = list(set(channel_list) - set(infeasible_channels))
upper_limits = {}
output_cols = []
actual_output_dic = {}
actual_input_dic = {}
for inp_col in channel_list:
# st.write(inp_col)
spends = input_df[inp_col].values
x = spends.copy()
# upper limit for penalty
upper_limits[inp_col] = 2 * x.max()
# contribution
# out_col = [_col for _col in output_df.columns if _col.startswith(inp_col)][0]
out_col = inp_col
y = output_df[out_col].values.copy()
actual_output_dic[inp_col] = y.copy()
actual_input_dic[inp_col] = x.copy()
##output cols aggregation
output_cols.append(out_col)
return pd.DataFrame(actual_input_dic), pd.DataFrame(actual_output_dic)
# Function to initialize model results data
def initialize_data(panel=None, metrics=None):
# Extract dataframes for raw data, spend input, and contribution data
raw_df = st.session_state["project_dct"]["current_media_performance"][
"model_outputs"
][metrics]["raw_data"].copy()
spend_df = st.session_state["project_dct"]["current_media_performance"][
"model_outputs"
][metrics]["spends_data"].copy()
contribution_df = st.session_state["project_dct"]["current_media_performance"][
"model_outputs"
][metrics]["contribution_data"].copy()
# Check if 'Panel' or 'panel' is in the columns
panel_column = None
if "Panel" in raw_df.columns:
panel_column = "Panel"
elif "panel" in raw_df.columns:
panel_column = "panel"
# Filter data by panel if provided
if panel and panel.lower() != "aggregated":
raw_df = raw_df[raw_df[panel_column] == panel].drop(columns=[panel_column])
spend_df = spend_df[spend_df[panel_column] == panel].drop(
columns=[panel_column]
)
contribution_df = contribution_df[contribution_df[panel_column] == panel].drop(
columns=[panel_column]
)
else:
raw_df = panel_level(raw_df, date_column="Date")
spend_df = panel_level(spend_df, date_column="Date")
contribution_df = panel_level(contribution_df, date_column="Date")
# Remove unnecessary columns
unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
exclude_columns = ["Date"] + unnamed_cols
# Convert Date columns to datetime
for df in [raw_df, spend_df, contribution_df]:
df["Date"] = pd.to_datetime(df["Date"], format="%Y-%m-%d", errors="coerce")
# Sort data by Date
input_df = raw_df.sort_values(by="Date")
contribution_df = contribution_df.sort_values(by="Date")
spend_df.sort_values(by="Date", inplace=True)
# Extract channels excluding unwanted columns
channel_list = [col for col in input_df.columns if col not in exclude_columns]
# Filter out channels with non-positive contributions
negative_contributions = [
col
for col in contribution_df.select_dtypes(include=["float", "int"]).columns
if contribution_df[col].sum() <= 0
]
channel_list = list(set(channel_list) - set(negative_contributions))
# Initialize dictionaries for metrics and response curves
response_curves, mapes, rmses, upper_limits = {}, {}, {}, {}
r2_scores, powers, conversion_rates, actual_output, actual_input = (
{},
{},
{},
{},
{},
)
channels = {}
sales = None
dates = input_df["Date"].values
# Fit s-curve for each channel
for channel in channel_list:
spends = input_df[channel].values
x = spends.copy()
upper_limits[channel] = 2 * x.max()
# Get corresponding output column
output_col = [
_col for _col in contribution_df.columns if _col.startswith(channel)
][0]
y = contribution_df[output_col].values.copy()
actual_output[channel] = y.copy()
actual_input[channel] = x.copy()
# Scale input data
power = np.ceil(np.log(x.max()) / np.log(10)) - 3
if power >= 0:
x = x / 10**power
x, y = x.astype("float64"), y.astype("float64")
# Set bounds for curve fitting
if y.max() <= 0.01:
bounds = (
(0, 0, 0, 0),
(3 * 0.01, 1000, 1, x.max() if x.max() > 0 else 0.01),
)
else:
bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max()))
# Set y to 0 where x is 0
y[x == 0] = 0
# Fit s-curve and calculate metrics
# params, _ = curve_fit(
# s_curve,
# x
# y,
# p0=(2 * y.max(), 0.01, 1e-5, x.max()),
# bounds=bounds,
# maxfev=int(1e6),
# )
params, _ = curve_fit(
s_curve,
list(x) + [0] * len(x),
list(y) + [0] * len(y),
p0=(2 * y.max(), 0.01, 1e-5, x.max()),
bounds=bounds,
maxfev=int(1e6),
)
mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean()
rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean())
r2_score_ = r2_score(y, s_curve(x, *params))
# Store metrics and parameters
response_curves[channel] = {
"K": params[0],
"b": params[1],
"a": params[2],
"x0": params[3],
}
mapes[channel] = mape
rmses[channel] = rmse
r2_scores[channel] = r2_score_
powers[channel] = power
conversion_rate = spend_df[channel].sum() / max(input_df[channel].sum(), 1e-9)
conversion_rates[channel] = conversion_rate
correction = y - s_curve(x, *params)
# Initialize Channel object
channel_obj = Channel(
name=channel,
dates=dates,
spends=spends,
conversion_rate=conversion_rate,
response_curve_type="s-curve",
response_curve_params={
"K": params[0],
"b": params[1],
"a": params[2],
"x0": params[3],
},
bounds=np.array([-10, 10]),
correction=correction,
)
channels[channel] = channel_obj
if sales is None:
sales = channel_obj.actual_sales
else:
sales += channel_obj.actual_sales
# Calculate other contributions
other_contributions = (
contribution_df.drop(columns=[*response_curves.keys()])
.sum(axis=1, numeric_only=True)
.values
)
# Initialize Scenario object
scenario = Scenario(
name="default",
channels=channels,
constant=other_contributions,
correction=np.array([]),
)
# Set session state variables
st.session_state.update(
{
"initialized": True,
"actual_df": input_df,
"raw_df": raw_df,
"contri_df": contribution_df,
"default_scenario_dict": class_to_dict(scenario),
"scenario": scenario,
"channels_list": channel_list,
"optimization_channels": {
channel_name: False for channel_name in channel_list
},
"rcs": response_curves.copy(),
"powers": powers,
"actual_contribution_df": pd.DataFrame(actual_output),
"actual_input_df": pd.DataFrame(actual_input),
"xlsx_buffer": io.BytesIO(),
"saved_scenarios": (
pickle.load(open("../saved_scenarios.pkl", "rb"))
if Path("../saved_scenarios.pkl").exists()
else OrderedDict()
),
"disable_download_button": True,
}
)
for channel in channels.values():
st.session_state[channel.name] = numerize(
channel.actual_total_spends * channel.conversion_rate, 1
)
# Prepare response curve data for output
response_curve_data = {}
for channel, params in st.session_state["rcs"].items():
x = st.session_state["actual_input_df"][channel].values.astype(float)
y = st.session_state["actual_contribution_df"][channel].values.astype(float)
power = float(np.ceil(np.log(max(x)) / np.log(10)) - 3)
x_plot = list(np.linspace(0, 5 * max(x), 100))
response_curve_data[channel] = {
"K": float(params["K"]),
"b": float(params["b"]),
"a": float(params["a"]),
"x0": float(params["x0"]),
"power": power,
"x": list(x),
"y": list(y),
"x_plot": x_plot,
}
return response_curve_data, scenario
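
# Illustrative call of initialize_data (a sketch: "revenue" must already be a key under
# st.session_state["project_dct"]["current_media_performance"]["model_outputs"], and
# "paid_search_spend" is a hypothetical channel name):
#
#   rcs_data, scenario = initialize_data(panel="aggregated", metrics="revenue")
#   rcs_data["paid_search_spend"]["K"]        # fitted saturation ceiling for the channel
#   scenario.channels["paid_search_spend"]    # Channel object with spends and response curve
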
# def initialize_data(panel=None, metrics=None):
# # Extract dataframes for raw data, spend input, and contribution data
# raw_df = st.session_state["project_dct"]["current_media_performance"][
# "model_outputs"
# ][metrics]["raw_data"]
# spend_df = st.session_state["project_dct"]["current_media_performance"][
# "model_outputs"
# ][metrics]["spends_data"]
# contri_df = st.session_state["project_dct"]["current_media_performance"][
# "model_outputs"
# ][metrics]["contribution_data"]
# # Check if the panel is not None
# if panel is not None and panel.lower() != "aggregated":
# raw_df = raw_df[raw_df["Panel"] == panel].drop(columns=["Panel"])
# spend_df = spend_df[spend_df["Panel"] == panel].drop(columns=["Panel"])
# contri_df = contri_df[contri_df["Panel"] == panel].drop(columns=["Panel"])
# elif panel.lower() == "aggregated":
# raw_df = panel_level(raw_df, date_column="Date")
# spend_df = panel_level(spend_df, date_column="Date")
# contri_df = panel_level(contri_df, date_column="Date")
# ## remove sesonalities, indices etc ...
# unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
# ## remove sesonalities, indices etc ...
# exclude_columns = ["Date"] + unnamed_cols
# raw_df["Date"] = pd.to_datetime(raw_df["Date"], format="%Y-%m-%d", errors="coerce")
# contri_df["Date"] = pd.to_datetime(
# contri_df["Date"], format="%Y-%m-%d", errors="coerce"
# )
# spend_df["Date"] = pd.to_datetime(
# spend_df["Date"], format="%Y-%m-%d", errors="coerce"
# )
# input_df = raw_df.sort_values(by="Date")
# output_df = contri_df.sort_values(by="Date")
# spend_df.sort_values(by="Date", inplace=True)
# channel_list = [col for col in input_df.columns if col not in exclude_columns]
# negative_contribution = [
# c
# for c in contri_df.select_dtypes(include=["float", "int"]).columns
# if contri_df[c].sum() <= 0
# ]
# channel_list = list(set(channel_list) - set(negative_contribution))
# response_curves = {}
# mapes = {}
# rmses = {}
# upper_limits = {}
# powers = {}
# r2 = {}
# conv_rates = {}
# output_cols = []
# channels = {}
# sales = None
# dates = input_df.Date.values
# actual_output_dic = {}
# actual_input_dic = {}
# for inp_col in channel_list:
# spends = input_df[inp_col].values
# x = spends.copy()
# # upper limit for penalty
# upper_limits[inp_col] = 2 * x.max()
# # contribution
# out_col = [_col for _col in output_df.columns if _col.startswith(inp_col)][0]
# y = output_df[out_col].values.copy()
# actual_output_dic[inp_col] = y.copy()
# actual_input_dic[inp_col] = x.copy()
# ##output cols aggregation
# output_cols.append(out_col)
# ## scale the input
# power = np.ceil(np.log(x.max()) / np.log(10)) - 3
# if power >= 0:
# x = x / 10**power
# x = x.astype("float64")
# y = y.astype("float64")
# if y.max() <= 0.01:
# if x.max() <= 0.0:
# bounds = ((0, 0, 0, 0), (3 * 0.01, 1000, 1, 0.01))
# else:
# bounds = ((0, 0, 0, 0), (3 * 0.01, 1000, 1, x.max()))
# else:
# bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max()))
# params, _ = curve_fit(
# s_curve,
# x,
# y,
# p0=(2 * y.max(), 0.01, 1e-5, x.max()),
# bounds=bounds,
# maxfev=int(1e5),
# )
# mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean()
# rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean())
# r2_ = r2_score(y, s_curve(x, *params))
# response_curves[inp_col] = {
# "K": params[0],
# "b": params[1],
# "a": params[2],
# "x0": params[3],
# }
# mapes[inp_col] = mape
# rmses[inp_col] = rmse
# r2[inp_col] = r2_
# powers[inp_col] = power
# conv = spend_df[inp_col].sum() / max(input_df[inp_col].sum(), 1e-9)
# conv_rates[inp_col] = conv
# correction = y - s_curve(x, *params)
# channel = Channel(
# name=inp_col,
# dates=dates,
# spends=spends,
# conversion_rate=conv_rates[inp_col],
# response_curve_type="s-curve",
# response_curve_params={
# "K": params[0],
# "b": params[1],
# "a": params[2],
# "x0": params[3],
# },
# bounds=np.array([-10, 10]),
# correction=correction,
# )
# channels[inp_col] = channel
# if sales is None:
# sales = channel.actual_sales
# else:
# sales += channel.actual_sales
# other_contributions = (
# output_df.drop([*output_cols], axis=1).sum(axis=1, numeric_only=True).values
# )
# scenario = Scenario(
# name="default",
# channels=channels,
# constant=other_contributions,
# correction=np.array([]),
# )
# ## setting session variables
# st.session_state["initialized"] = True
# st.session_state["actual_df"] = input_df
# st.session_state["raw_df"] = raw_df
# st.session_state["contri_df"] = output_df
# default_scenario_dict = class_to_dict(scenario)
# st.session_state["default_scenario_dict"] = default_scenario_dict
# st.session_state["scenario"] = scenario
# st.session_state["channels_list"] = channel_list
# st.session_state["optimization_channels"] = {
# channel_name: False for channel_name in channel_list
# }
# st.session_state["rcs"] = response_curves.copy()
# st.session_state["powers"] = powers
# st.session_state["actual_contribution_df"] = pd.DataFrame(actual_output_dic)
# st.session_state["actual_input_df"] = pd.DataFrame(actual_input_dic)
# for channel in channels.values():
# st.session_state[channel.name] = numerize(
# channel.actual_total_spends * channel.conversion_rate, 1
# )
# st.session_state["xlsx_buffer"] = io.BytesIO()
# if Path("../saved_scenarios.pkl").exists():
# with open("../saved_scenarios.pkl", "rb") as f:
# st.session_state["saved_scenarios"] = pickle.load(f)
# else:
# st.session_state["saved_scenarios"] = OrderedDict()
# # st.session_state["total_spends_change"] = 0
# st.session_state["optimization_channels"] = {
# channel_name: False for channel_name in channel_list
# }
# st.session_state["disable_download_button"] = True
# rcs_data = {}
# for channel in st.session_state["rcs"]:
# # Convert to native Python lists and types
# x = list(st.session_state["actual_input_df"][channel].values.astype(float))
# y = list(
# st.session_state["actual_contribution_df"][channel].values.astype(float)
# )
# power = float(np.ceil(np.log(max(x)) / np.log(10)) - 3)
# x_plot = list(np.linspace(0, 5 * max(x), 100))
# rcs_data[channel] = {
# "K": float(st.session_state["rcs"][channel]["K"]),
# "b": float(st.session_state["rcs"][channel]["b"]),
# "a": float(st.session_state["rcs"][channel]["a"]),
# "x0": float(st.session_state["rcs"][channel]["x0"]),
# "power": power,
# "x": x,
# "y": y,
# "x_plot": x_plot,
# }
# return rcs_data, scenario
# def initialize_data():
# # fetch data from excel
# output = pd.read_excel('data.xlsx',sheet_name=None)
# raw_df = output['RAW DATA MMM']
# contribution_df = output['CONTRIBUTION MMM']
# Revenue_df = output['Revenue']
# ## channels to be shows
# channel_list = []
# for col in raw_df.columns:
# if 'click' in col.lower() or 'spend' in col.lower() or 'imp' in col.lower():
# channel_list.append(col)
# else:
# pass
# ## NOTE : Considered only Desktop spends for all calculations
# acutal_df = raw_df[raw_df.Region == 'Desktop'].copy()
# ## NOTE : Considered one year of data
# acutal_df = acutal_df[acutal_df.Date>'2020-12-31']
# actual_df = acutal_df.drop('Region',axis=1).sort_values(by='Date')[[*channel_list,'Date']]
# ##load response curves
# with open('./grammarly_response_curves.json','r') as f:
# response_curves = json.load(f)
# ## create channel dict for scenario creation
# dates = actual_df.Date.values
# channels = {}
# rcs = {}
# constant = 0.
# for i,info_dict in enumerate(response_curves):
# name = info_dict.get('name')
# response_curve_type = info_dict.get('response_curve')
# response_curve_params = info_dict.get('params')
# rcs[name] = response_curve_params
# if name != 'constant':
# spends = actual_df[name].values
# channel = Channel(name=name,dates=dates,
# spends=spends,
# response_curve_type=response_curve_type,
# response_curve_params=response_curve_params,
# bounds=np.array([-30,30]))
# channels[name] = channel
# else:
# constant = info_dict.get('value',0.) * len(dates)
# ## create scenario
# scenario = Scenario(name='default', channels=channels, constant=constant)
# default_scenario_dict = class_to_dict(scenario)
# ## setting session variables
# st.session_state['initialized'] = True
# st.session_state['actual_df'] = actual_df
# st.session_state['raw_df'] = raw_df
# st.session_state['default_scenario_dict'] = default_scenario_dict
# st.session_state['scenario'] = scenario
# st.session_state['channels_list'] = channel_list
# st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
# st.session_state['rcs'] = rcs
# for channel in channels.values():
# if channel.name not in st.session_state:
# st.session_state[channel.name] = float(channel.actual_total_spends)
# if 'xlsx_buffer' not in st.session_state:
# st.session_state['xlsx_buffer'] = io.BytesIO()
# ## for saving scenarios
# if 'saved_scenarios' not in st.session_state:
# if Path('../saved_scenarios.pkl').exists():
# with open('../saved_scenarios.pkl','rb') as f:
# st.session_state['saved_scenarios'] = pickle.load(f)
# else:
# st.session_state['saved_scenarios'] = OrderedDict()
# if 'total_spends_change' not in st.session_state:
# st.session_state['total_spends_change'] = 0
# if 'optimization_channels' not in st.session_state:
# st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
# if 'disable_download_button' not in st.session_state:
# st.session_state['disable_download_button'] = True
def create_channel_summary(scenario):
# Provided data
data = {
"Channel": [
"Paid Search",
"Ga will cid baixo risco",
"Digital tactic others",
"Fb la tier 1",
"Fb la tier 2",
"Paid social others",
"Programmatic",
"Kwai",
"Indicacao",
"Infleux",
"Influencer",
],
"Spends": [
"$ 11.3K",
"$ 155.2K",
"$ 50.7K",
"$ 125.4K",
"$ 125.2K",
"$ 105K",
"$ 3.3M",
"$ 47.5K",
"$ 55.9K",
"$ 632.3K",
"$ 48.3K",
],
"Revenue": [
"558.0K",
"3.5M",
"5.2M",
"3.1M",
"3.1M",
"2.1M",
"20.8M",
"1.6M",
"728.4K",
"22.9M",
"4.8M",
],
}
# Create DataFrame
df = pd.DataFrame(data)
# Convert currency strings to numeric values
df["Spends"] = (
df["Spends"]
        .replace({r"\$": "", "K": "*1e3", "M": "*1e6"}, regex=True)
.map(pd.eval)
.astype(int)
)
df["Revenue"] = (
df["Revenue"]
        .replace({r"\$": "", "K": "*1e3", "M": "*1e6"}, regex=True)
.map(pd.eval)
.astype(int)
)
# Calculate ROI
df["ROI"] = (df["Revenue"] - df["Spends"]) / df["Spends"]
# Format columns
format_currency = lambda x: f"${x:,.1f}"
format_roi = lambda x: f"{x:.1f}"
df["Spends"] = [
"$ 11.3K",
"$ 155.2K",
"$ 50.7K",
"$ 125.4K",
"$ 125.2K",
"$ 105K",
"$ 3.3M",
"$ 47.5K",
"$ 55.9K",
"$ 632.3K",
"$ 48.3K",
]
df["Revenue"] = [
"$ 536.3K",
"$ 3.4M",
"$ 5M",
"$ 3M",
"$ 3M",
"$ 2M",
"$ 20M",
"$ 1.5M",
"$ 7.1M",
"$ 22M",
"$ 4.6M",
]
df["ROI"] = df["ROI"].apply(format_roi)
return df
# @st.cache(allow_output_mutation=True)
# def create_contribution_pie(scenario):
# #c1f7dc
# colors_map = {col:color for col,color in zip(st.session_state['channels_list'],plotly.colors.n_colors(plotly.colors.hex_to_rgb('#BE6468'), plotly.colors.hex_to_rgb('#E7B8B7'),23))}
# total_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "pie"}, {"type": "pie"}]])
# total_contribution_fig.add_trace(
# go.Pie(labels=[channel_name_formating(channel_name) for channel_name in st.session_state['channels_list']] + ['Non Media'],
# values= [round(scenario.channels[channel_name].actual_total_spends * scenario.channels[channel_name].conversion_rate,1) for channel_name in st.session_state['channels_list']] + [0],
# marker=dict(colors = [plotly.colors.label_rgb(colors_map[channel_name]) for channel_name in st.session_state['channels_list']] + ['#F0F0F0']),
# hole=0.3),
# row=1, col=1)
# total_contribution_fig.add_trace(
# go.Pie(labels=[channel_name_formating(channel_name) for channel_name in st.session_state['channels_list']] + ['Non Media'],
# values= [scenario.channels[channel_name].actual_total_sales for channel_name in st.session_state['channels_list']] + [scenario.correction.sum() + scenario.constant.sum()],
# hole=0.3),
# row=1, col=2)
# total_contribution_fig.update_traces(textposition='inside',texttemplate='%{percent:.1%}')
# total_contribution_fig.update_layout(uniformtext_minsize=12,title='Channel contribution', uniformtext_mode='hide')
# return total_contribution_fig
# @st.cache(allow_output_mutation=True)
# def create_contribuion_stacked_plot(scenario):
# weekly_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "bar"}, {"type": "bar"}]])
# raw_df = st.session_state['raw_df']
# df = raw_df.sort_values(by='Date')
# x = df.Date
# weekly_spends_data = []
# weekly_sales_data = []
# for channel_name in st.session_state['channels_list']:
# weekly_spends_data.append((go.Bar(x=x,
# y=scenario.channels[channel_name].actual_spends * scenario.channels[channel_name].conversion_rate,
# name=channel_name_formating(channel_name),
# hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
# legendgroup=channel_name)))
# weekly_sales_data.append((go.Bar(x=x,
# y=scenario.channels[channel_name].actual_sales,
# name=channel_name_formating(channel_name),
# hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
# legendgroup=channel_name, showlegend=False)))
# for _d in weekly_spends_data:
# weekly_contribution_fig.add_trace(_d, row=1, col=1)
# for _d in weekly_sales_data:
# weekly_contribution_fig.add_trace(_d, row=1, col=2)
# weekly_contribution_fig.add_trace(go.Bar(x=x,
# y=scenario.constant + scenario.correction,
# name='Non Media',
# hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), row=1, col=2)
# weekly_contribution_fig.update_layout(barmode='stack', title='Channel contribuion by week', xaxis_title='Date')
# weekly_contribution_fig.update_xaxes(showgrid=False)
# weekly_contribution_fig.update_yaxes(showgrid=False)
# return weekly_contribution_fig
# @st.cache(allow_output_mutation=True)
# def create_channel_spends_sales_plot(channel):
# if channel is not None:
# x = channel.dates
# _spends = channel.actual_spends * channel.conversion_rate
# _sales = channel.actual_sales
# channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
# channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False)
# channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#005b96'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True)
# channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise')
# channel_sales_spends_fig.update_xaxes(showgrid=False)
# channel_sales_spends_fig.update_yaxes(showgrid=False)
# else:
# raw_df = st.session_state['raw_df']
# df = raw_df.sort_values(by='Date')
# x = df.Date
# scenario = class_from_dict(st.session_state['default_scenario_dict'])
# _sales = scenario.constant + scenario.correction
# channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
# channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False)
# # channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#15C39A'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True)
# channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise')
# channel_sales_spends_fig.update_xaxes(showgrid=False)
# channel_sales_spends_fig.update_yaxes(showgrid=False)
# return channel_sales_spends_fig
# Define a shared color palette
def create_contribution_pie():
color_palette = [
"#F3F3F0",
"#5E7D7E",
"#2FA1FF",
"#00EDED",
"#00EAE4",
"#304550",
"#EDEBEB",
"#7FBEFD",
"#003059",
"#A2F3F3",
"#E1D6E2",
"#B6B6B6",
]
total_contribution_fig = make_subplots(
rows=1,
cols=2,
subplot_titles=["Spends", "Revenue"],
specs=[[{"type": "pie"}, {"type": "pie"}]],
)
channels_list = [
"Paid Search",
"Ga will cid baixo risco",
"Digital tactic others",
"Fb la tier 1",
"Fb la tier 2",
"Paid social others",
"Programmatic",
"Kwai",
"Indicacao",
"Infleux",
"Influencer",
"Non Media",
]
# Assign colors from the limited palette to channels
colors_map = {
col: color_palette[i % len(color_palette)]
for i, col in enumerate(channels_list)
}
    colors_map["Non Media"] = color_palette[5]  # Assign a fixed color for 'Non Media'
# Hardcoded values for Spends and Revenue
spends_values = [0.5, 3.36, 1.1, 2.7, 2.7, 2.27, 70.6, 1, 1, 13.7, 1, 0]
revenue_values = [1, 4, 5, 3, 3, 2, 50.8, 1.5, 0.7, 13, 0, 16]
# Add trace for Spends pie chart
total_contribution_fig.add_trace(
go.Pie(
labels=[channel_name for channel_name in channels_list],
values=spends_values,
marker=dict(
colors=[colors_map[channel_name] for channel_name in channels_list]
),
hole=0.3,
),
row=1,
col=1,
)
# Add trace for Revenue pie chart
total_contribution_fig.add_trace(
go.Pie(
labels=[channel_name for channel_name in channels_list],
values=revenue_values,
marker=dict(
colors=[colors_map[channel_name] for channel_name in channels_list]
),
hole=0.3,
),
row=1,
col=2,
)
total_contribution_fig.update_traces(
textposition="inside", texttemplate="%{percent:.1%}"
)
total_contribution_fig.update_layout(
uniformtext_minsize=12,
title="Channel contribution",
uniformtext_mode="hide",
)
return total_contribution_fig
def create_contribuion_stacked_plot(scenario):
weekly_contribution_fig = make_subplots(
rows=1,
cols=2,
subplot_titles=["Spends", "Revenue"],
specs=[[{"type": "bar"}, {"type": "bar"}]],
)
raw_df = st.session_state["raw_df"]
df = raw_df.sort_values(by="Date")
x = df.Date
weekly_spends_data = []
weekly_sales_data = []
for i, channel_name in enumerate(st.session_state["channels_list"]):
color = color_palette[i % len(color_palette)]
weekly_spends_data.append(
go.Bar(
x=x,
y=scenario.channels[channel_name].actual_spends
* scenario.channels[channel_name].conversion_rate,
name=channel_name_formating(channel_name),
                hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
legendgroup=channel_name,
marker_color=color,
)
)
weekly_sales_data.append(
go.Bar(
x=x,
y=scenario.channels[channel_name].actual_sales,
name=channel_name_formating(channel_name),
                hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
legendgroup=channel_name,
showlegend=False,
marker_color=color,
)
)
for _d in weekly_spends_data:
weekly_contribution_fig.add_trace(_d, row=1, col=1)
for _d in weekly_sales_data:
weekly_contribution_fig.add_trace(_d, row=1, col=2)
weekly_contribution_fig.add_trace(
go.Bar(
x=x,
y=scenario.constant + scenario.correction,
name="Non Media",
            hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
marker_color=color_palette[-1],
),
row=1,
col=2,
)
weekly_contribution_fig.update_layout(
barmode="stack",
title="Channel contribution by week",
xaxis_title="Date",
)
weekly_contribution_fig.update_xaxes(showgrid=False)
weekly_contribution_fig.update_yaxes(showgrid=False)
return weekly_contribution_fig
def create_channel_spends_sales_plot(channel):
if channel is not None:
x = channel.dates
_spends = channel.actual_spends * channel.conversion_rate
_sales = channel.actual_sales
channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
channel_sales_spends_fig.add_trace(
go.Bar(
x=x,
y=_sales,
marker_color=color_palette[
3
], # You can choose a color from the palette
name="Revenue",
                hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
),
secondary_y=False,
)
channel_sales_spends_fig.add_trace(
go.Scatter(
x=x,
y=_spends,
line=dict(
color=color_palette[2]
), # You can choose another color from the palette
name="Spends",
                hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
),
secondary_y=True,
)
channel_sales_spends_fig.update_layout(
xaxis_title="Date",
yaxis_title="Revenue",
yaxis2_title="Spends ($)",
title="Channel spends and Revenue week-wise",
)
channel_sales_spends_fig.update_xaxes(showgrid=False)
channel_sales_spends_fig.update_yaxes(showgrid=False)
else:
raw_df = st.session_state["raw_df"]
df = raw_df.sort_values(by="Date")
x = df.Date
scenario = class_from_dict(st.session_state["default_scenario_dict"])
_sales = scenario.constant + scenario.correction
channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
channel_sales_spends_fig.add_trace(
go.Bar(
x=x,
y=_sales,
marker_color=color_palette[
0
], # You can choose a color from the palette
name="Revenue",
                hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
),
secondary_y=False,
)
channel_sales_spends_fig.update_layout(
xaxis_title="Date",
yaxis_title="Revenue",
yaxis2_title="Spends ($)",
title="Channel spends and Revenue week-wise",
)
channel_sales_spends_fig.update_xaxes(showgrid=False)
channel_sales_spends_fig.update_yaxes(showgrid=False)
return channel_sales_spends_fig
def format_numbers(value, n_decimals=1, include_indicator=True):
if value is None:
return None
_value = value if value < 1 else numerize(value, n_decimals)
if include_indicator:
return f"{CURRENCY_INDICATOR} {_value}"
else:
return f"{_value}"
def decimal_formater(num_string, n_decimals=1):
parts = num_string.split(".")
if len(parts) == 1:
return num_string + "." + "0" * n_decimals
else:
to_be_padded = n_decimals - len(parts[-1])
if to_be_padded > 0:
return num_string + "0" * to_be_padded
else:
return num_string
def channel_name_formating(channel_name):
name_mod = channel_name.replace("_", " ")
if name_mod.lower().endswith(" imp"):
name_mod = name_mod.replace("Imp", "Spend")
elif name_mod.lower().endswith(" clicks"):
name_mod = name_mod.replace("Clicks", "Spend")
return name_mod
def send_email(email, message):
s = smtplib.SMTP("smtp.gmail.com", 587)
s.starttls()
s.login("geethu4444@gmail.com", "jgydhpfusuremcol")
s.sendmail("geethu4444@gmail.com", email, message)
s.quit()
# if __name__ == "__main__":
# initialize_data()
#############################################################################################################
# Function to get panels names
def get_panels_names(file_selected):
raw_data_df = st.session_state["project_dct"]["current_media_performance"][
"model_outputs"
][file_selected]["raw_data"]
if "panel" in raw_data_df.columns:
panel = list(set(raw_data_df["panel"]))
elif "Panel" in raw_data_df.columns:
panel = list(set(raw_data_df["Panel"]))
else:
panel = []
return panel + ["aggregated"]
# Function to get metrics names
def get_metrics_names():
return list(
st.session_state["project_dct"]["current_media_performance"][
"model_outputs"
].keys()
)
# Function to load the original and modified rcs metadata files into dictionaries
def load_rcs_metadata_files():
original_data = st.session_state["project_dct"]["response_curves"][
"original_metadata_file"
]
modified_data = st.session_state["project_dct"]["response_curves"][
"modified_metadata_file"
]
return original_data, modified_data
# Function to format name
def name_formating(name):
# Replace underscores with spaces
name_mod = name.replace("_", " ")
# Capitalize the first letter of each word
name_mod = name_mod.title()
return name_mod
# Function to load the original and modified scenario metadata files into dictionaries
def load_scenario_metadata_files():
original_data = st.session_state["project_dct"]["scenario_planner"][
"original_metadata_file"
]
modified_data = st.session_state["project_dct"]["scenario_planner"][
"modified_metadata_file"
]
return original_data, modified_data
# Function to generate RCS data and store it as dictionary
def generate_rcs_data():
# Retrieve the list of all metric names from the specified directory
metrics_list = get_metrics_names()
# Dictionary to store RCS data for all metrics and their respective panels
all_rcs_data_original = {}
all_rcs_data_modified = {}
# Iterate over each metric in the metrics list
for metric in metrics_list:
# Retrieve the list of panel names from the current metric's Excel file
panel_list = get_panels_names(file_selected=metric)
# Check if rcs_data_modified exist
if (
st.session_state["project_dct"]["response_curves"]["modified_metadata_file"]
is not None
):
modified_data = st.session_state["project_dct"]["response_curves"][
"modified_metadata_file"
]
# Iterate over each panel in the panel list
for panel in panel_list:
# Initialize the original RCS data for the current panel and metric
rcs_dict_original, scenario = initialize_data(
panel=panel,
metrics=metric,
)
# Ensure the dictionary has the metric as a key for original data
if metric not in all_rcs_data_original:
all_rcs_data_original[metric] = {}
# Store the original RCS data under the corresponding panel for the current metric
all_rcs_data_original[metric][panel] = rcs_dict_original
# Ensure the dictionary has the metric as a key for modified data
if metric not in all_rcs_data_modified:
all_rcs_data_modified[metric] = {}
# Store the modified RCS data under the corresponding panel for the current metric
for channel in rcs_dict_original:
all_rcs_data_modified[metric][panel] = all_rcs_data_modified[
metric
].get(panel, {})
try:
updated_rcs_dict = modified_data[metric][panel][channel]
                except Exception:
updated_rcs_dict = {
"K": rcs_dict_original[channel]["K"],
"b": rcs_dict_original[channel]["b"],
"a": rcs_dict_original[channel]["a"],
"x0": rcs_dict_original[channel]["x0"],
}
all_rcs_data_modified[metric][panel][channel] = updated_rcs_dict
# Write the original RCS data
st.session_state["project_dct"]["response_curves"][
"original_metadata_file"
] = all_rcs_data_original
# Write the modified RCS data
st.session_state["project_dct"]["response_curves"][
"modified_metadata_file"
] = all_rcs_data_modified
# Function to generate scenario data and store it as dictionary
def generate_scenario_data():
# Retrieve the list of all metric names from the specified directory
metrics_list = get_metrics_names()
# Dictionary to store scenario data for all metrics and their respective panels
all_scenario_data_original = {}
all_scenario_data_modified = {}
# Iterate over each metric in the metrics list
for metric in metrics_list:
# Retrieve the list of panel names from the current metric's Excel file
panel_list = get_panels_names(metric)
# Check if scenario_data_modified exist
if (
st.session_state["project_dct"]["scenario_planner"][
"modified_metadata_file"
]
is not None
):
modified_data = st.session_state["project_dct"]["scenario_planner"][
"modified_metadata_file"
]
# Iterate over each panel in the panel list
for panel in panel_list:
# Initialize the original scenario data for the current panel and metric
rcs_dict_original, scenario = initialize_data(
panel=panel,
metrics=metric,
)
# Ensure the dictionary has the metric as a key for original data
if metric not in all_scenario_data_original:
all_scenario_data_original[metric] = {}
# Store the original scenario data under the corresponding panel for the current metric
all_scenario_data_original[metric][panel] = class_convert_to_dict(scenario)
# Ensure the dictionary has the metric as a key for modified data
if metric not in all_scenario_data_modified:
all_scenario_data_modified[metric] = {}
# Store the modified scenario data under the corresponding panel for the current metric
try:
all_scenario_data_modified[metric][panel] = modified_data[metric][panel]
            except Exception:
all_scenario_data_modified[metric][panel] = class_convert_to_dict(
scenario
)
# Write the original scenario data
st.session_state["project_dct"]["scenario_planner"][
"original_metadata_file"
] = all_scenario_data_original
# Write the modified scenario data
st.session_state["project_dct"]["scenario_planner"][
"modified_metadata_file"
] = all_scenario_data_modified
#############################################################################################################