题目描述
Marton的朋友Cero有一个包含N个正整数的数组。开始时,Cero在黑板上写上第一个数字,然后,他将第二个数字写在第一个数字的左边或右边,之后,他将第三个数字写在目前为止写下的所有数字的左边或右边,以此类推。当他写下全部N个数字后,会形成一个新的数组。
●Marton想知道新数组的最长严格递增子序列的长度。
●Marton还想知道这种最长严格递增子序列的数量。
更确切的说,如果所有能构建出的新数组中最长严格递增子序列的最长长度为M,则想Marton知道所有可以构建的每个新数组中长度为M的最长严格递增子序列的数目的总和。如果新数组使用不同的顺序构建,则称为不同的新数组。对于同一个新数组,如果两个最长严格递增子序列在至少一个位置上不同,则称为两个不同的最长严格递增子序列。 考虑到这样的子序列的数目非常大,只需求出答案对109+7取模的结果。 Cero要求你来回答Marton的两个问题。
输入
第一行输入包含整数N(1≤N≤2*1e5)。表示数组中有多少个数。 第二行输入包含N个正整数Ai(1≤Ai≤1e9)。表示Cero原本的数组。
输出
输出两个整数,中间用空格分隔。分别表示最长严格递增子序列的长度和这种长度的最长严格递增子序列的的数目。
样例输入
2
1 1
样例输出
1 4
解题思路
考场顿时崩溃,什么垃圾题目
于是写了暴力,先dfs搞新数组,然后N方的最长上升子序列,于是骗了一些分,舒服。
其实我们先看一个点i,求出以它为左端点的最长上升与最长下降序列,你可能会问,这有啥用?
由于可以左右放,我们把那个最长下降的序列放到i的左边,诶,好像这样组合成了一个含i的最长上升序列。
绝对最长,可以证明…
那么,我们可以求出最长上升子序列的长度值了,那么如何求数量。
只需要求出现次数即可。
嗯,题解写得差不多了。代码长了点,但其实很好理解。其中有一个关键点,那就是上升和下降的树状数组写法不同。
#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<vector>
#include<queue>
#include<algorithm>
#define mod 1000000007
using namespace std;
struct node{
int num;
long long tot;
node(){
num = 0,tot = 0;
}
}tree[2][200005];
node dp[200005],dp1[200005];
int n,a[200005],b[200005],cnt,len;
long long ans;
long long qkpow(long long x,long long y){
long long ans = 1;
while (y){
if (y % 2 == 1)
ans = ans * x % mod;
x = x * x % mod;
y = y / 2;
}
return ans;
}
int lowbit(int x){
return x & (-x);
}
node check(node a,node b){
if (a.num == b.num)
a.tot = ( a.tot + b.tot ) % mod;
else if (a.num < b.num)
a = b;
return a;
}
void insert(int k,int pos,node x){
while (k <= cnt && k > 0){
tree[pos][k] = check(tree[pos][k],x);
k = k + (pos?-1:1) * lowbit(k);
}
}
node find(int k,int pos){
node ret;
while (k <= cnt && k > 0){
ret = check(tree[pos][k],ret);
k = k + (pos?1:-1) * lowbit(k);
}
return ret;
}
int main(){
//freopen("zoltan.in","r",stdin);
//freopen("zoltan.out","w",stdout);
scanf ("%d",&n);
for (int i = 1;i <= n;i ++){
scanf ("%d",&a[i]);
b[i] = a[i];
}
sort(b + 1,b + 1 + n);
cnt = unique(b + 1,b + 1 + n) - b - 1;
for (int i = 1;i <= n;i ++)
a[i] = lower_bound(b + 1,b + 1 + cnt,a[i]) - b;
for (int i = n;i >= 1;i --){
node p1 = find(a[i] + 1,1),s;
s.num = s.tot = 1;
if (!p1.num){
dp[i] = s;
insert(a[i],1,s);
}
else {
p1.num ++;
dp[i] = p1;
insert(a[i],1,p1);
}
}
for (int i = n;i >= 1;i --){
node p1 = find(a[i] - 1,0),s;
s.num = s.tot = 1;
if (!p1.num){
dp1[i] = s;
insert(a[i],0,s);
}
else {
p1.num ++;
dp1[i] = p1;
insert(a[i],0,p1);
}
}
for (int i = 1;i <= n;i ++)
len = max(len,dp[i].num + dp1[i].num - 1);
for (int i = 1;i <= n;i ++){
if (dp[i].num + dp1[i].num - 1 == len)
ans = (ans + (dp[i].tot * dp1[i].tot) % mod * qkpow(2,n - len) % mod) %mod;
}
printf("%d %lld",len,ans);
}