#include<bits/stdc++.h>
using namespace std;
#define ll long long
ll c[5022],dp[10001][3],r[10001],home[10001];
ll n,m,cnt;
struct node{
ll next,to;
}e[20001];
void add(ll a,ll b){
e[++cnt].to=n;
e[cnt].next=home[a];
home[a]=cnt;
}
void dfs(ll p,ll fa){
if(p<=n){
dp[p][0]=0x3f3f3f3f3f3f;
dp[p][1]=c[p]==0?1:0x3f3f3f3f3f3f;
dp[p][2]=c[p]==1?1:0x3f3f3f3f3f3f;
return ;
}
for(ll i=home[p];i;i=e[i].next){
ll t=e[i].to;
if(t==fa) continue;
dfs(t,p);
dp[p][0]+=min(dp[t][1],dp[t][2]);
dp[p][1]+=dp[t][1]-1;
dp[p][2]+=dp[t][2]-1;
}
dp[p][1]++;
dp[p][2]++;
}
int main(){
scanf("%lld %lld",&m,&n);
queue<ll> q;
for(ll i=1;i<=n;i++) scanf("%lld",&c[i]);
for(ll i=1;i<m;i++){
ll a,b;
scanf("%lld %lld",&a,&b);
add(a,b);
add(b,a);
r[a]++;
r[b]++;
if(r[a]==2) q.push(a);
if(r[b]==2) q.push(b);
}
ll ans=0x3f3f3f3f3f3f3f3f;
while(!q.empty()){
ll root=q.front();
q.pop();
for(ll i=1;i<=m;i++){
dp[i][1]=dp[i][0]=dp[i][2]=0;
}
dfs(root,0);
ans=min(ans,dp[root][0]);
}
printf("%lld",ans);
return 0;
}